package org.simantics.scl.compiler.internal.elaboration.transformations;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.simantics.scl.compiler.common.names.Names;
import org.simantics.scl.compiler.constants.Constant;
import org.simantics.scl.compiler.elaboration.contexts.EnvironmentalContext;
import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.elaboration.expressions.EApply;
import org.simantics.scl.compiler.elaboration.expressions.EConstant;
import org.simantics.scl.compiler.elaboration.expressions.EEnforce;
import org.simantics.scl.compiler.elaboration.expressions.ERuleset;
import org.simantics.scl.compiler.elaboration.expressions.EVariable;
import org.simantics.scl.compiler.elaboration.expressions.EWhen;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.expressions.Expressions;
import org.simantics.scl.compiler.elaboration.expressions.Variable;
import org.simantics.scl.compiler.elaboration.expressions.VariableProcedure;
import org.simantics.scl.compiler.elaboration.expressions.block.GuardStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.LetStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.Statement;
import org.simantics.scl.compiler.elaboration.query.QAtom;
import org.simantics.scl.compiler.elaboration.query.QConjunction;
import org.simantics.scl.compiler.elaboration.query.QMapping;
import org.simantics.scl.compiler.elaboration.query.Query;
import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
import org.simantics.scl.compiler.elaboration.rules.MappingRelation;
import org.simantics.scl.compiler.elaboration.rules.TransformationRule;
import org.simantics.scl.compiler.errors.ErrorLog;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.internal.codegen.references.IVal;
import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;
import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;

import gnu.trove.map.hash.THashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.THashSet;

/**
 * Builds the SCL expression that implements a transformation.
 * <p>
 * Seeds are registered with {@link #handleSeed(Query)}, each transformation rule with
 * {@link #handleRule(TransformationRule)}, and the final expression is produced by
 * {@link #compileRules()}. That expression is an {@link ERuleset} containing the Datalog
 * rules collected in {@link #sourceMatchingRules}. Its body nests the
 * {@link #mappingStatements} (outermost) around the per-phase {@link #enforcingStatements}.
 */
public class TransformationBuilder {
    private final ErrorLog errorLog;
    private final TypingContext context;
    private final UnifiableFactory unifiableFactory;
    
    // Auxiliary
    /**
     * Per-mapping-relation state. Instances are created lazily by getMapping().
     */
    static class Mapping {
        // Local relation "<name>_src" whose single column is typed by the mapping's parameter type 0.
        LocalRelation relation;
        // Variable "map_<name>" bound (by a let in mappingStatements) to this mapping's UMap.
        Variable umap;
    }
    // One Mapping per MappingRelation; entries are added only by getMapping().
    THashMap<MappingRelation, Mapping> mappings = new THashMap<MappingRelation, Mapping>();
    // The "<rule>_match" relation created for each rule by generateMatchingRules().
    THashMap<TransformationRule, LocalRelation> ruleSourceMatchRelations = new THashMap<TransformationRule, LocalRelation>();
    
    // Output
    // Datalog rules that become part of the ERuleset built by compileRules().
    ArrayList<ERuleset.DatalogRule> sourceMatchingRules = new ArrayList<ERuleset.DatalogRule>();
    // Statements placed outside the enforcing statements: UMap creation lets, seed puts and
    // per-rule mapping guards. The list is also handed to UnifiableFactory in the constructor
    // (presumably so it can append helper statements; confirm in UnifiableFactory).
    ArrayList<Statement> mappingStatements = new ArrayList<Statement>();
    // Enforcing statements grouped by phase number; compileRules() emits them in ascending phase order.
    TIntObjectHashMap<ArrayList<Statement>> enforcingStatements = new TIntObjectHashMap<ArrayList<Statement>>();
    
    /**
     * @param errorLog receives errors about seeds and rules that cannot be translated
     * @param context  typing context used for copying and compiling the generated expressions
     */
    public TransformationBuilder(ErrorLog errorLog, TypingContext context) {
        this.errorLog = errorLog;
        this.context = context;
        this.unifiableFactory = new UnifiableFactory(context, mappingStatements);
    }
    
    /**
     * Returns the state for the given mapping relation, creating it on first use.
     * <p>
     * On creation, this method:
     * <ul>
     * <li>allocates the {@code <name>_src} local relation, typed by parameter type 0;</li>
     * <li>allocates the {@code map_<name>} UMap variable;</li>
     * <li>appends to {@link #mappingStatements} a let that binds the variable to
     *     {@code Unifiable.createUMap} applied to unit. The call uses the relation's two
     *     parameter types as type arguments.</li>
     * </ul>
     * The let is appended when the mapping is first requested. It therefore precedes every
     * statement added to mappingStatements afterwards.
     *
     * @param mappingRelation the mapping relation; it is read as having two parameter types
     * @return the (possibly newly created) mapping state
     */
    private Mapping getMapping(MappingRelation mappingRelation) {
        Mapping mapping = mappings.get(mappingRelation);
        if(mapping == null) {
            mapping = new Mapping();
            mapping.relation =  new LocalRelation(mappingRelation.name.name+"_src", new Type[] {
                    mappingRelation.parameterTypes[0]
            });
            mapping.umap = new Variable("map_" + mappingRelation.name.name,
                    Types.apply(Names.Unifiable_UMap, mappingRelation.parameterTypes)
                    );
            mappings.put(mappingRelation, mapping);
            mappingStatements.add(new LetStatement(new EVariable(mapping.umap),
                    Expressions.apply(context.getCompilationContext(), Types.PROC, Names.Unifiable_createUMap,
                            mappingRelation.parameterTypes[0],
                            mappingRelation.parameterTypes[1],
                            Expressions.punit())));
        }
        return mapping;
    }
    
    /**
     * Variable-use visitor that records uses of the variables in {@code variableSet}.
     * Each such use increments the variable's count in {@code mappedVariableUseCount}
     * and sets {@link #containsVariables}. Uses of other variables are ignored.
     */
    private static class PatternAnalyzer implements VariableProcedure {
        THashSet<Variable> variableSet;
        TObjectIntHashMap<Variable> mappedVariableUseCount;
        // True once at least one variable from variableSet has been seen.
        boolean containsVariables;
        
        public PatternAnalyzer(THashSet<Variable> variableSet,
                TObjectIntHashMap<Variable> mappedVariableUseCount) {
            this.variableSet = variableSet;
            this.mappedVariableUseCount = mappedVariableUseCount;
        }

        @Override
        public void execute(long location, Variable variable) {
            if(!variableSet.contains(variable))
                return;
            
            mappedVariableUseCount.adjustOrPutValue(variable, 1, 1);
            
            containsVariables = true;
        }
    }
    
    /**
     * Nests {@code statements} around {@code in}.
     * <p>
     * The statements are folded from last to first, so the first statement ends up
     * outermost and {@code in} innermost.
     * NOTE(review): the meaning of the {@code false} flag passed to Statement.toExpression
     * is not visible here; check Statement.
     */
    private static Expression statementsToExpression(EnvironmentalContext context, List<Statement> statements, Expression in) {
        for(int i=statements.size()-1;i>=0;--i)
            in = statements.get(i).toExpression(context, false, in);
        return in;
    }
    
    /**
     * Like the three-argument overload, with {@code Expressions.tuple()} as the innermost expression.
     */
    private static Expression statementsToExpression(EnvironmentalContext context, List<Statement> statements) {
        return statementsToExpression(context, statements, Expressions.tuple());
    }
    
    /**
     * Registers a seed of the transformation.
     * <p>
     * A {@link QMapping} does two things:
     * <ul>
     * <li>adds a Datalog rule, with an empty conjunction as its body, that places the
     *     mapping's parameter 0 into the mapping's {@code _src} relation;</li>
     * <li>appends a guard statement to {@link #mappingStatements} that puts the pair
     *     (parameter 0, parameter 1) into the mapping's UMap.</li>
     * </ul>
     * A {@link QConjunction} is handled component by component. Any other query is
     * reported to the error log.
     *
     * @param query the seed query
     */
    public void handleSeed(Query query) {
        if(query instanceof QMapping) {
            QMapping mapping = (QMapping)query;
            Mapping m = getMapping(mapping.mappingRelation);
            sourceMatchingRules.add(new ERuleset.DatalogRule(
                    query.location,
                    m.relation,
                    new Expression[] {mapping.parameters[0]},
                    new QConjunction(),
                    Variable.EMPTY_ARRAY
                    ));
            mappingStatements.add(new GuardStatement(unifiableFactory.putToUMapConstant(m.umap,
                    mapping.parameters[0].copy(context),
                    mapping.parameters[1].copy(context))));
        }
        else if(query instanceof QConjunction) {
            QConjunction conjunction = (QConjunction)query;
            for(Query childQuery : conjunction.queries)
                handleSeed(childQuery);
        }
        else {
            errorLog.log(query.location, "Cannot use the query as a seed for the transformation.");
        }
    }

    /**
     * Translates one transformation rule into matching rules, mapping statements and
     * enforcing statements.
     * <p>
     * Outline of what the code below does:
     * <ol>
     * <li>Decomposes the rule. For every source-side mapping, it adds an atom over that
     *     mapping's {@code _src} relation to the source queries.</li>
     * <li>Splits the rule's existential variables into two groups:
     *     <ul>
     *     <li>source variables, which occur in the source queries;</li>
     *     <li>the remaining {@code variableSet}.</li>
     *     </ul>
     *     Logs an error for any target mapping whose parameter 0 uses a variable from
     *     {@code variableSet}.</li>
     * <li>Generates the Datalog matching rules (see {@link #generateMatchingRules}).</li>
     * <li>Classifies every mapping by its parameter 1:
     *     <ul>
     *     <li>closed: mentions no variable from {@code variableSet};</li>
     *     <li>open: a single variable from {@code variableSet};</li>
     *     <li>semi-open: another expression that mentions such variables.</li>
     *     </ul>
     *     Also counts the uses of each mentioned variable.</li>
     * <li>Phase 2 actions put values into the UMaps. If there are any, they are appended to
     *     {@link #mappingStatements} inside an EWhen over the rule's match relation.</li>
     * <li>Phase 3 actions read values back from the UMaps and bind variables. If the rule
     *     has target queries, these actions wrap the enforcement of each target-query phase.
     *     Each result goes into an EWhen over the match relation and is appended to
     *     {@link #enforcingStatements} under that phase.</li>
     * </ol>
     *
     * @param rule the rule to translate
     */
    public void handleRule(TransformationRule rule) {
        // Collect and classify queries
        final DecomposedRule decomposed = DecomposedRule.decompose(context, rule, true);
        for(QMapping mapping : decomposed.sourceMappings) {
            decomposed.sourceQueries.add(new QAtom(
                    getMapping(mapping.mappingRelation).relation,
                    Type.EMPTY_ARRAY,
                    mapping.parameters[0].copy(context)
                    ));
        }
        
        // Source variables
        /* 
         * Collect the existential variables occurring in the rule so that
         *     sourceVariables = the variables that can be solved with source patterns 
         *                       including the sources of the mapping relations in when section
         *     variableSet = all other existential variables that are solved in mapping/enforcing phases 
         */
        final THashSet<Variable> variableSet = new THashSet<Variable>(rule.variables.length);
        for(Variable variable : rule.variables)
            variableSet.add(variable);

        Variable[] sourceVariables;
        {
            final ArrayList<Variable> sourceVariableList = new ArrayList<Variable>(rule.variables.length);
            // Moves every variable that occurs in a source query from variableSet to sourceVariableList.
            VariableProcedure analyze = new VariableProcedure() {
                @Override
                public void execute(long location, Variable variable) {
                    if(variableSet.remove(variable))
                        sourceVariableList.add(variable);
                }
            };
            for(Query query : decomposed.sourceQueries)
                query.forVariables(analyze);
            // The source parameter of a target mapping may only use variables solved by the source patterns.
            VariableProcedure check = new VariableProcedure() {
                @Override
                public void execute(long location, Variable variable) {
                    if(variableSet.contains(variable))
                        errorLog.log(location, "Cannot resolve the variable " + variable.getName() + " using the source patterns.");
                }
            };
            for(QMapping mapping : decomposed.targetMappings)
                mapping.parameters[0].forVariableUses(check);
            sourceVariables = sourceVariableList.toArray(new Variable[sourceVariableList.size()]);
        }
        
        // Matching rules
        generateMatchingRules(decomposed, sourceVariables);
        
        // Mapped variables
        // NOTE: this local list shadows the field `mappings`; getMapping() still reads the field.
        ArrayList<QMapping> mappings = new ArrayList<QMapping>(
                decomposed.sourceMappings.size() + decomposed.targetMappings.size());
        mappings.addAll(decomposed.sourceMappings);
        mappings.addAll(decomposed.targetMappings);

        // Analyze mappings
        // Classification looks only at parameter 1 and which variableSet variables it mentions.
        int capacity = Math.max(10, mappings.size());
        ArrayList<QMapping> closedMappings = new ArrayList<QMapping>(capacity);
        ArrayList<QMapping> openMappings = new ArrayList<QMapping>(capacity);
        ArrayList<QMapping> semiopenMappings = new ArrayList<QMapping>(capacity);

        TObjectIntHashMap<Variable> mappedVariableUseCount = new TObjectIntHashMap<Variable>();
        for(QMapping mapping : mappings) {
            Expression expression = mapping.parameters[1];
            if(expression instanceof EVariable) {
                Variable variable = ((EVariable)expression).getVariable();
                if(variableSet.contains(variable)) {
                    // Single open variable
                    mappedVariableUseCount.adjustOrPutValue(variable, 1, 1);
                    openMappings.add(mapping);
                }
                else {
                    // Single variable whose value is bound
                    closedMappings.add(mapping);
                }
            }
            else {
                PatternAnalyzer analyzer = new PatternAnalyzer(variableSet, mappedVariableUseCount); 
                expression.forVariableUses(analyzer);

                if(analyzer.containsVariables)
                    semiopenMappings.add(mapping);
                else
                    closedMappings.add(mapping);
            }
        }

        // Generate mapping actions
        // phase2Actions go into the mapping statement; phase3Actions wrap the enforcement below.
        ArrayList<Statement> phase2Actions = new ArrayList<Statement>();
        ArrayList<Statement> phase3Actions = new ArrayList<Statement>();
        for(QMapping mapping : closedMappings)
            phase2Actions.add(new GuardStatement(unifiableFactory.putToUMapConstant(
                    getMapping(mapping.mappingRelation).umap,
                    mapping.parameters[0].copy(context),
                    mapping.parameters[1].copy(context))));

        // Choose and initialize shared unification variables
        // A variable used by more than one mapping gets a "uvar_<name>" variable, created in
        // phase 2 with Unifiable.uVar, so that all of its uses can share one unifiable value.
        THashMap<Variable, Variable> uniVariableMap =
                new THashMap<Variable, Variable>();
        for(Variable variable : mappedVariableUseCount.keySet()) {
            int count = mappedVariableUseCount.get(variable);
            if(count > 1) {
                Variable uniVariable = new Variable("uvar_" + variable.getName(),
                        Types.apply(Names.Unifiable_Unifiable, variable.getType()));
                phase2Actions.add(new LetStatement(new EVariable(uniVariable), 
                        Expressions.apply(context.getCompilationContext(), Types.PROC,
                                Names.Unifiable_uVar,
                                variable.getType(),
                                Expressions.tuple())));
                uniVariableMap.put(variable, uniVariable);
            }
        }

        // Select open mappings that use shared variables
        // Open mappings on a shared variable are handled as semi-open. Each other open mapping
        // binds its variable directly from the UMap in phase 3, and that variable is then
        // no longer undetermined.
        THashSet<Variable> undeterminedVariables = new THashSet<Variable>(variableSet); 
        for(QMapping mapping : openMappings) {
            Variable variable = ((EVariable)mapping.parameters[1]).getVariable();
            if(uniVariableMap.containsKey(variable))
                semiopenMappings.add(mapping);
            else {
                Mapping m = getMapping(mapping.mappingRelation);
                Type resultType = mapping.mappingRelation.parameterTypes[1];
                phase3Actions.add(new LetStatement(new EVariable(variable),
                        unifiableFactory.getFromUMap(Expressions.var(m.umap), mapping.parameters[0].copy(context), resultType)));
                undeterminedVariables.remove(variable);
            }
        }

        // For each semi-open mapping:
        //  - phase 2 puts the target into the UMap, using the shared unification variables;
        //  - phase 3 reads the value back and destructures it with a pattern. toPattern()
        //    removes each variable it binds from undeterminedVariables, so every variable
        //    is bound by at most one pattern.
        for(QMapping mapping : semiopenMappings) {
            Mapping m = getMapping(mapping.mappingRelation);
            Type valueType = mapping.mappingRelation.parameterTypes[1];
            phase2Actions.add(new GuardStatement(unifiableFactory.putToUMapUnifiable(
                    variableSet, uniVariableMap,
                    Expressions.var(m.umap), mapping.parameters[0].copy(context), mapping.parameters[1].copy(context))));
            
            Expression pattern = toPattern(undeterminedVariables, mapping.parameters[1]);
            if(pattern != null) {
                Expression value = unifiableFactory.getFromUMap(Expressions.var(m.umap), mapping.parameters[0].copy(context), valueType);
                phase3Actions.add(new LetStatement(pattern, value));
            }
        }
        
        // Mapping statement
        // Runs the phase 2 actions for every tuple of the rule's match relation.
        if(!phase2Actions.isEmpty())
            mappingStatements.add(new GuardStatement(new EWhen(
                    rule.location,
                    new QAtom(decomposed.ruleMatchingRelation,
                            Type.EMPTY_ARRAY,
                            Expressions.vars(sourceVariables)),
                            statementsToExpression(context.getCompilationContext(), phase2Actions),
                            sourceVariables).compile(context)));

        // Enforcing statement
        if(!decomposed.targetQueries.isEmpty()) {
            // Variables that are neither solved by source patterns nor mentioned by any
            // mapping target are bound to a default value of their type.
            // NOTE(review): a variable counted in mappedVariableUseCount that toPattern() never
            // binds (e.g. one that occurs only under a non-constructor application) gets
            // neither a phase-3 let nor a default value here. Confirm this case cannot
            // arise or is rejected elsewhere.
            for(Variable variable : rule.variables)
                if(variableSet.contains(variable) && !mappedVariableUseCount.containsKey(variable))
                    phase3Actions.add(new LetStatement(new EVariable(variable), unifiableFactory.generateDefaultValue(variable.getType())));
            
            TIntObjectHashMap<ArrayList<Query>> phases = new TIntObjectHashMap<ArrayList<Query>>();
            for(Query targetQuery : decomposed.targetQueries)
                targetQuery.splitToPhases(phases);
            
            // One enforcing statement per phase. Each one is wrapped in the same phase 3
            // actions and guarded by the rule's match relation.
            for(int phase : phases.keys()) {
                ArrayList<Query> targetQuery = phases.get(phase);
                Expression enforcing = new EEnforce(new QConjunction(targetQuery.toArray(new Query[targetQuery.size()]))).compile(context);
                enforcing = statementsToExpression(context.getCompilationContext(), phase3Actions, enforcing);
                enforcing = new EWhen(
                        rule.location,
                        new QAtom(decomposed.ruleMatchingRelation,
                                Type.EMPTY_ARRAY,
                                Expressions.vars(sourceVariables)),
                                enforcing,
                                sourceVariables).compile(context);
                ArrayList<Statement> list = enforcingStatements.get(phase);
                if(list == null) {
                    list = new ArrayList<Statement>();
                    enforcingStatements.put(phase, list);
                }
                // ForcedClosure wrapping is controlled by the
                // EVERY_RULE_ENFORCEMENT_IN_SEPARATE_METHOD configuration flag.
                list.add(new GuardStatement(ForcedClosure.forceClosure(enforcing.copy(context),
                        SCLCompilerConfiguration.EVERY_RULE_ENFORCEMENT_IN_SEPARATE_METHOD)));
            }
        }
    }
    
    /**
     * Assembles the final expression after all seeds and rules have been handled.
     * <p>
     * The enforcing statements of all phases are concatenated in ascending phase order and
     * nested inside the mapping statements. The result becomes the body of an
     * {@link ERuleset} that:
     * <ul>
     * <li>declares every rule match relation and every mapping {@code _src} relation;</li>
     * <li>contains {@link #sourceMatchingRules}.</li>
     * </ul>
     * The ruleset is compiled with the typing context before it is returned.
     *
     * @return the compiled transformation expression
     */
    public Expression compileRules() {
        ArrayList<LocalRelation> localRelations = new ArrayList<LocalRelation>();
        localRelations.addAll(ruleSourceMatchRelations.values());
        for(Mapping mapping : mappings.values())
            localRelations.add(mapping.relation);
        
        ArrayList<Statement> allEnforcingStatements;
        // A single phase needs no sorting or copying; its list is used directly.
        if(enforcingStatements.size() == 1)
            allEnforcingStatements = enforcingStatements.valueCollection().iterator().next();
        else {
            int[] phases = enforcingStatements.keys();
            Arrays.sort(phases);
            allEnforcingStatements = new ArrayList<Statement>();
            for(int phase : phases)
                allEnforcingStatements.addAll(enforcingStatements.get(phase));
        }
        // Enforcing statements innermost, mapping statements wrapped around them.
        Expression expression = statementsToExpression(context.getCompilationContext(), allEnforcingStatements);
        expression = statementsToExpression(context.getCompilationContext(), mappingStatements, expression);
        
        // Matching
        Expression result = new ERuleset(
                localRelations.toArray(new LocalRelation[localRelations.size()]),
                sourceMatchingRules.toArray(new ERuleset.DatalogRule[sourceMatchingRules.size()]),
                expression
                ).compile(context);
        return result;
    }
        
    /**
     * Builds a pattern that binds the variables of {@code expression} that are still in
     * {@code undeterminedVariables}. Each variable it binds is removed from that set.
     * <p>
     * The expression is converted as follows:
     * <ul>
     * <li>A variable becomes a variable pattern if it is still undetermined.</li>
     * <li>An application becomes a constructor pattern when all of these hold:
     *     <ul>
     *     <li>its function is an EConstant;</li>
     *     <li>that constant's value is a {@link Constant} with a non-negative constructor tag;</li>
     *     <li>the constant's arity equals the number of arguments.</li>
     *     </ul>
     *     Arguments that bind nothing are replaced by blanks of their type.</li>
     * <li>Everything else yields {@code null}.</li>
     * </ul>
     *
     * @param undeterminedVariables variables not yet bound; modified in place
     * @param expression            the mapping target expression to destructure
     * @return the pattern, or {@code null} if the expression binds no undetermined variable
     */
    private Expression toPattern(
            THashSet<Variable> undeterminedVariables,
            Expression expression) {
        if(expression instanceof EVariable) {
            Variable variable = ((EVariable)expression).getVariable();
            if(undeterminedVariables.remove(variable))
                return new EVariable(variable);
            else
                return null;
        }
        if(expression instanceof EApply) {
            EApply apply = (EApply)expression;
            
            if(!(apply.getFunction() instanceof EConstant))
                return null;
            EConstant function = (EConstant)apply.getFunction();
            
            IVal val = function.getValue().getValue();
            if(!(val instanceof Constant))
                return null;
            Constant constant = (Constant)val;
            
            // Only constructor applications (non-negative tag) can be used as patterns.
            int constructorTag = constant.constructorTag();
            if(constructorTag < 0)
                return null;
            
            // Partial or over-saturated applications are not destructured.
            int arity = constant.getArity();
            Expression[] parameters = apply.getParameters(); 
            if(arity != parameters.length)
                return null;
            
            Expression[] patterns = new Expression[arity];
            boolean noUndeterminedVariables = true;
            for(int i=0;i<arity;++i) {
                Expression pattern = toPattern(undeterminedVariables, parameters[i]); 
                patterns[i] = pattern;
                if(pattern != null)
                    noUndeterminedVariables = false;
            }
            if(noUndeterminedVariables)
                return null;
            
            for(int i=0;i<arity;++i)
                if(patterns[i] == null)
                    patterns[i] = Expressions.blank(parameters[i].getType());
            return new EApply(Locations.NO_LOCATION, apply.getEffect(), apply.getFunction().copy(context), patterns);
        } 
        
        return null;
    }

    /**
     * Creates the {@code <rule>_match} relation over the types of the source variables.
     * It is stored in {@code decomposed.ruleMatchingRelation} and in
     * {@link #ruleSourceMatchRelations}.
     * <p>
     * This method also adds the following Datalog rules. It assumes DatalogRule's
     * arguments are: location, head relation, head arguments, body, variables.
     * <ul>
     * <li>{@code <rule>_match(sourceVariables)} with the conjunction of the source queries
     *     as its body;</li>
     * <li>for every target mapping, its {@code _src} relation applied to the mapping's
     *     parameter 0, with {@code <rule>_match(sourceVariables)} as the body.</li>
     * </ul>
     *
     * @param decomposed      the decomposed rule; its ruleMatchingRelation field is assigned here
     * @param sourceVariables the variables solved by the source patterns
     */
    private void generateMatchingRules(DecomposedRule decomposed, Variable[] sourceVariables) {
        // @when/from -sections
        decomposed.ruleMatchingRelation =
                new LocalRelation(decomposed.rule.name.name+"_match", Types.getTypes(sourceVariables));
        ruleSourceMatchRelations.put(decomposed.rule, decomposed.ruleMatchingRelation);
        sourceMatchingRules.add(new ERuleset.DatalogRule(decomposed.rule.location,
                decomposed.ruleMatchingRelation,
                Expressions.vars(sourceVariables), 
                new QConjunction(decomposed.sourceQueries.toArray(new Query[decomposed.sourceQueries.size()])),
                sourceVariables));
        
        // @where -section
        for(QMapping mapping : decomposed.targetMappings)
            sourceMatchingRules.add(new ERuleset.DatalogRule(decomposed.rule.location,
                    getMapping(mapping.mappingRelation).relation,
                    new Expression[] {mapping.parameters[0].copy(context)},
                    new QAtom(decomposed.ruleMatchingRelation,
                            Type.EMPTY_ARRAY,
                            Expressions.vars(sourceVariables)),
                            sourceVariables));
    }
}
