package org.simantics.scl.compiler.elaboration.expressions;

import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;

import java.util.ArrayList;
import java.util.Set;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.common.names.Names;
import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;
import org.simantics.scl.compiler.elaboration.query.Query;
import org.simantics.scl.compiler.elaboration.query.Query.Diff;
import org.simantics.scl.compiler.elaboration.query.Query.Diffable;
import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;
import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;
import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;
import org.simantics.scl.compiler.types.kinds.Kinds;

import gnu.trove.impl.Constants;
import gnu.trove.map.hash.THashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.THashSet;
import gnu.trove.set.hash.TIntHashSet;

/**
 * A block of local, Datalog-style relation definitions together with an expression
 * ({@code in}) that is evaluated after the relations have been computed.
 * <p>
 * A ruleset does not survive type checking. {@link #checkBasicType}, {@link #inferType}
 * and {@link #checkIgnoredType} type-check the rules and {@code in}, then return the
 * result of {@link #compile(TypingContext)}. That method rewrites the ruleset into ordinary
 * procedural code which fills the relation tables and then evaluates {@code in}.
 */
public class ERuleset extends SimplifiableExpression {
    /** The relations defined by this ruleset. */
    LocalRelation[] relations;
    /** The defining rules; each rule's head relation must be one of {@link #relations}. */
    DatalogRule[] rules;
    /** The expression evaluated after the relations have been computed. */
    Expression in;
    
    public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {
        this.relations = relations;
        this.rules = rules;
        this.in = in;
    }

    /**
     * A rule that derives rows of {@code headRelation}: for every solution of {@code body},
     * the tuple built from {@code headParameters} is added to the relation.
     */
    public static class DatalogRule {
        /** Source location of the rule. */
        public long location;
        public LocalRelation headRelation;
        public Expression[] headParameters;
        public Query body;
        /**
         * Variables bound by the rule. The three-argument constructor leaves this
         * {@code null}. NOTE(review): presumably it is assigned later, because type
         * checking and free-variable collection dereference it.
         */
        public Variable[] variables;
        
        public DatalogRule(LocalRelation headRelation, Expression[] headParameters,
                Query body) {
            this.headRelation = headRelation;
            this.headParameters = headParameters;
            this.body = body;
        }
        
        public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,
                Query body, Variable[] variables) {
            this.location = location;
            this.headRelation = headRelation;
            this.headParameters = headParameters;
            this.body = body;
            this.variables = variables;
        }

        /**
         * Sets the location of this rule to {@code loc} and propagates it to the head
         * parameters and the body.
         */
        public void setLocationDeep(long loc) {
            this.location = loc;
            for(Expression parameter : headParameters)
                parameter.setLocationDeep(loc);
            body.setLocationDeep(loc);
        }
        
        @Override
        public String toString() {
            StringBuilder b = new StringBuilder();
            ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);
            visitor.visit(this);
            return b.toString();
        }

        /** Applies {@code procedure} to the variables of the head parameters and the body. */
        public void forVariables(VariableProcedure procedure) {
            for(Expression headParameter : headParameters)
                headParameter.forVariables(procedure);
            body.forVariables(procedure);
        }
    }
    
    /**
     * Type-checks every rule. Each rule variable gets a fresh meta-variable type, and each
     * head parameter is checked against the corresponding parameter type of the head
     * relation. The body is then type-checked.
     */
    private void checkRuleTypes(TypingContext context) {
        for(DatalogRule rule : rules) {
            Type[] parameterTypes =  rule.headRelation.getParameterTypes();
            Expression[] parameters = rule.headParameters;
            for(Variable variable : rule.variables)
                variable.setType(Types.metaVar(Kinds.STAR));
            for(int i=0;i<parameters.length;++i)
                parameters[i] = parameters[i].checkType(context, parameterTypes[i]);
            rule.body.checkType(context);
        }
    }
    
    @Override
    public Expression checkBasicType(TypingContext context, Type requiredType) {
        checkRuleTypes(context);
        in = in.checkBasicType(context, requiredType);
        return compile(context);
    }
    
    @Override
    public Expression inferType(TypingContext context) {
        checkRuleTypes(context);
        in = in.inferType(context);
        return compile(context);
    }
    
    @Override
    public Expression checkIgnoredType(TypingContext context) {
        checkRuleTypes(context);
        in = in.checkIgnoredType(context);
        return compile(context);
    }
    
    @Override
    public void collectFreeVariables(THashSet<Variable> vars) {
        for(DatalogRule rule : rules) {
            for(Expression parameter : rule.headParameters)
                parameter.collectFreeVariables(vars);
            rule.body.collectFreeVariables(vars);
            // Variables bound by the rule itself are not free in the ruleset.
            for(Variable var : rule.variables)
                vars.remove(var);
        }
        in.collectFreeVariables(vars);
    }
    
    @Override
    public void collectRefs(TObjectIntHashMap<Object> allRefs,
            TIntHashSet refs) {
        for(DatalogRule rule : rules) {
            for(Expression parameter : rule.headParameters)
                parameter.collectRefs(allRefs, refs);
            rule.body.collectRefs(allRefs, refs);
        }
        in.collectRefs(allRefs, refs);
    }
    
    @Override
    public void collectVars(TObjectIntHashMap<Variable> allVars,
            TIntHashSet vars) {
        for(DatalogRule rule : rules) {
            for(Expression parameter : rule.headParameters)
                parameter.collectVars(allVars, vars);
            rule.body.collectVars(allVars, vars);
        }
        in.collectVars(allVars, vars);
    }
    
    @Override
    public void collectEffects(THashSet<Type> effects) {
        throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support collectEffects.");
    }
    
    @Override
    public Expression decorate(ExpressionDecorator decorator) {
        return decorator.decorate(this);
    }
    
    @Override
    public Expression resolve(TranslationContext context) {
        throw new InternalCompilerError();
    }
    
    static class LocalRelationAux {
        Variable handleFunc;
    }
    
    /**
     * Compiles the ruleset into plain expressions.
     * <p>
     * Relation A depends on relation B when B is referenced by the body or head parameters
     * of a rule whose head is A. The relations are split into strongly connected components
     * of this dependency graph, and each component is compiled as a separate stratum with
     * {@link #compileStratified(TypingContext)}. The component at index 0 ends up
     * outermost and is therefore evaluated first. NOTE(review): this assumes
     * {@link StronglyConnectedComponents} reports a component only after the components
     * it depends on; confirm against that class.
     */
    public Expression compile(TypingContext context) {
        // Map each relation to its index in the relations array.
        TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,
                Constants.DEFAULT_LOAD_FACTOR, -1);
        for(int i=0;i<relations.length;++i)
            relationsToIds.put(relations[i], i);
        
        // For each relation, collect the ids of the relations its rules reference.
        TIntHashSet[] refsSets = new TIntHashSet[relations.length];
        int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);
        for(int i=0;i<relations.length;++i)
            refsSets[i] = new TIntHashSet(setCapacity);
        
        for(DatalogRule rule : rules) {
            int headRelationId = relationsToIds.get(rule.headRelation);
            TIntHashSet refsSet = refsSets[headRelationId];
            rule.body.collectRelationRefs(relationsToIds, refsSet);
            for(Expression parameter : rule.headParameters)
                parameter.collectRelationRefs(relationsToIds, refsSet);
        }
        
        // Convert refsSets to an adjacency array.
        final int[][] refs = new int[relations.length][];
        for(int i=0;i<relations.length;++i)
            refs[i] = refsSets[i].toArray();
        
        // Find the strongly connected components of the dependency graph.
        final ArrayList<int[]> components = new ArrayList<int[]>();
        
        new StronglyConnectedComponents(relations.length) {
            @Override
            protected void reportComponent(int[] component) {
                components.add(component);
            }
            
            @Override
            protected int[] findDependencies(int u) {
                return refs[u];
            }
        }.findComponents();
        
        // With a single component, no stratification is needed.
        if(components.size() == 1) {
            return compileStratified(context);
        }
        
        // Inverse of the components list: stratum index of every relation.
        int[] strataPerRelation = new int[relations.length];
        for(int i=0;i<components.size();++i)
            for(int k : components.get(i))
                strataPerRelation[k] = i;
        
        // Collect the rules belonging to each stratum.
        @SuppressWarnings("unchecked")
        ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];
        for(int i=0;i<components.size();++i)
            rulesPerStrata[i] = new ArrayList<DatalogRule>();
        for(DatalogRule rule : rules) {
            int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];
            rulesPerStrata[stratum].add(rule);
        }
        
        // Build the stratified system from the innermost (last) stratum outwards.
        Expression cur = this.in;
        for(int stratum=components.size()-1;stratum >= 0;--stratum) {
            int[] cs = components.get(stratum);
            LocalRelation[] curRelations = new LocalRelation[cs.length];
            for(int i=0;i<cs.length;++i)
                curRelations[i] = relations[cs[i]];
            ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];
            cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);
        }
        return cur;
    }
    
    /**
     * Wraps {@code action} in an {@link EWhen} over {@code query}, copies it with
     * {@code context}, and compiles the copy.
     */
    private static Expression compileWhen(TypingContext context, long location, Query query,
            Expression action, Variable[] variables) {
        EWhen when = (EWhen)new EWhen(location, query, action, variables).copy(context);
        return when.compile(context);
    }
    
    /**
     * Compiles a single stratum: all relations of this ruleset are solved together by
     * worklist iteration.
     * <p>
     * The generated code, from the outside in:
     * <ol>
     * <li>binds each relation's {@code table} to a new {@code MSet};</li>
     * <li>inside a forced closure, binds a new {@code MList} stack for each relation;</li>
     * <li>runs the seed expressions, which push initial rows onto the stacks;</li>
     * <li>defines one mutually recursive loop function per relation and calls the first one.
     * Each loop pops a row from its stack. If {@code MSet_add} of the row into the table
     * returns true, it runs the update expressions derived for that relation. When its
     * stack is empty, it moves on to the next relation's loop. The counter stops the
     * process after a full round of empty stacks;</li>
     * <li>finally evaluates {@code in}.</li>
     * </ol>
     * Logs an error and returns an {@link EError} if a rule body cannot be differentiated.
     */
    private Expression compileStratified(TypingContext context) {
        Expression continuation = Expressions.tuple();
        
        // Create stacks
        Variable[] stacks = new Variable[relations.length];
        for(int i=0;i<relations.length;++i) {
            LocalRelation relation = relations[i];
            Type[] parameterTypes = relation.getParameterTypes();
            stacks[i] = newVar("stack" + relation.getName(),
                    Types.apply(Names.MList_T, Types.tuple(parameterTypes))
                    );
        }

        // Simplify subexpressions and collect derivatives
        THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);
        for(int i=0;i<relations.length;++i) {
            LocalRelation relation = relations[i];
            Type[] parameterTypes = relation.getParameterTypes();
            Variable[] parameters = new Variable[parameterTypes.length];
            for(int j=0;j<parameterTypes.length;++j)
                parameters[j] = new Variable("p" + j, parameterTypes[j]);
            diffables.put(relations[i], new Diffable(i, relation, parameters));
        }
        
        // Set view of this stratum's relations, used to strip them from recursive rule
        // bodies. The cast only widens the element type (LocalRelation -> SCLRelation).
        // NOTE(review): this is safe provided Query.removeRelations only reads the set.
        @SuppressWarnings({ "unchecked", "rawtypes" })
        Set<SCLRelation> stratumRelations = (Set<SCLRelation>)(Set)diffables.keySet();
        
        @SuppressWarnings("unchecked")
        ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];
        for(int i=0;i<relations.length;++i)
            updateExpressions[i] = new ArrayList<Expression>(2);
        ArrayList<Expression> seedExpressions = new ArrayList<Expression>(); 
        for(DatalogRule rule : rules) {
            int id = diffables.get(rule.headRelation).id;
            // Pushes the rule's head tuple onto the stack of its head relation.
            Expression appendExp = apply(context.getCompilationContext(), Types.PROC, Names.MList_add, Types.tuple(rule.headRelation.getParameterTypes()),
                    var(stacks[id]),
                    tuple(rule.headParameters)
                    );
            Diff[] diffs;
            try {
                diffs = rule.body.derivate(diffables);
            } catch(DerivateException e) {
                context.getErrorLog().log(e.location, "Recursion must not contain negations or aggregates.");
                return new EError();
            }
            for(Diff diff : diffs)
                updateExpressions[diff.id].add(compileWhen(context, rule.location, diff.query, appendExp, rule.variables));
            if(diffs.length == 0)
                // The body does not reference this stratum: the whole rule is a seed.
                seedExpressions.add(compileWhen(context, rule.location, rule.body, appendExp, rule.variables));
            else {
                // Use the rule's own location, like the other EWhen sites. Rulesets
                // created per stratum in compile() have no location of their own.
                Query query = rule.body.removeRelations(stratumRelations);
                if(query != Query.EMPTY_QUERY)
                    seedExpressions.add(compileWhen(context, rule.location, query, appendExp, rule.variables));
            }
        }
        
        // Iterative solving of relations

        Variable[] loops = new Variable[relations.length];
        for(int i=0;i<loops.length;++i)
            loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));
        continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);
        
        Expression[] loopDefs = new Expression[relations.length];
        for(int i=0;i<relations.length;++i) {
            LocalRelation relation = relations[i];
            Type[] parameterTypes = relation.getParameterTypes();
            Variable[] parameters = diffables.get(relation).parameters;
            
            // Number of further relations whose loops may still be tried before stopping.
            Variable counter = newVar("counter", Types.INTEGER);
            
            Type rowType = Types.tuple(parameterTypes);
            Variable row = newVar("row", rowType);
            
            // Runs the update expressions only if MSet_add returns true for the row,
            // then restarts this loop with a full counter.
            Expression handleRow = tuple();
            for(Expression updateExpression : updateExpressions[i])
                handleRow = seq(updateExpression, handleRow);
            handleRow = if_(
                    apply(context.getCompilationContext(), Types.PROC, Names.MSet_add, rowType,
                            var(relation.table), var(row)),
                    handleRow,
                    tuple()
                    );
            handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));
            // Stack empty: stop when the counter reaches zero, otherwise try the next relation.
            Expression failure =
                    if_(isZeroInteger(var(counter)),
                        tuple(),
                        apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))
                       );
            Expression body = matchWithDefault(
                    apply(context.getCompilationContext(), Types.PROC, Names.MList_removeLast, rowType, var(stacks[i])),
                    Just(as(row, tuple(vars(parameters)))), handleRow,
                    failure);
            
            loopDefs[i] = lambda(Types.PROC, counter, body); 
        }
        continuation = letRec(loops, loopDefs, continuation);
        
        // Seed relations
        for(Expression seedExpression : seedExpressions)
            continuation = seq(seedExpression, continuation);
        
        // Create stacks
        for(int i=0;i<stacks.length;++i)
            continuation = let(stacks[i],
                    apply(context.getCompilationContext(), Types.PROC, Names.MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),
                    continuation);
        
        continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);
        
        // Create relations
        for(LocalRelation relation : relations)
            continuation = let(relation.table,
                    apply(context.getCompilationContext(), Types.PROC, Names.MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),
                    continuation);
        
        return seq(continuation, in);
    }

    @Override
    protected void updateType() throws MatchException {
        setType(in.getType());
    }
    
    /**
     * Sets the location of this ruleset, its rules and its {@code in} expression. It does
     * nothing if the ruleset already has a location.
     */
    @Override
    public void setLocationDeep(long loc) {
        if(location == Locations.NO_LOCATION) {
            location = loc;
            for(DatalogRule rule : rules)
                rule.setLocationDeep(loc);
            in.setLocationDeep(loc);
        }
    }
    
    @Override
    public void accept(ExpressionVisitor visitor) {
        visitor.visit(this);
    }

    public DatalogRule[] getRules() {
        return rules;
    }
    
    public Expression getIn() {
        return in;
    }

    @Override
    public void forVariables(VariableProcedure procedure) {
        for(DatalogRule rule : rules)
            rule.forVariables(procedure);
        in.forVariables(procedure);
    }
    
    @Override
    public Expression accept(ExpressionTransformer transformer) {
        return transformer.transform(this);
    }

}
