package org.simantics.scl.compiler.elaboration.expressions;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.elaboration.chr.CHRRuleset;
import org.simantics.scl.compiler.elaboration.chr.translation.CHRTranslation;
import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
import org.simantics.scl.compiler.elaboration.expressions.block.CHRStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.ConstraintStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.GuardStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.LetStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.RuleStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.Statement;
import org.simantics.scl.compiler.elaboration.expressions.block.StatementGroup;
import org.simantics.scl.compiler.errors.Locations;

/**
 * AST node for a block of statements.
 * <p>
 * During {@link #resolve(TranslationContext)} the statements are folded from the last one
 * towards the first. The block must end with a {@link GuardStatement}, whose value is the
 * innermost expression. Each preceding statement wraps the expression built so far.
 * Consecutive statements that share the same non-null {@link StatementGroup} are extracted
 * together as a single let-group, rule set or CHR rule set.
 */
public class EBlock extends ASTExpression {

    LinkedList<Statement> statements = new LinkedList<Statement>();
    boolean monadic;
    
    public EBlock() {
    }

    /**
     * Appends a statement to the end of this block.
     *
     * @param statement the statement to append
     */
    public void addStatement(Statement statement) {
        statements.add(statement);
    }
    
    /**
     * Sets the flag that {@link #resolve(TranslationContext)} passes to
     * {@code Statement.toExpression} for ungrouped statements. When set,
     * {@link #getSyntacticFunctionArity()} reports 0.
     *
     * @param monadic the new flag value
     */
    public void setMonadic(boolean monadic) {
        this.monadic = monadic;
    }
    
    /**
     * Returns the live (mutable) list of statements of this block.
     *
     * @return the statement list backing this block; changes to it affect the block
     */
    public LinkedList<Statement> getStatements() {
        return statements;
    }

    /**
     * Desugars the block into nested expressions and resolves the result.
     *
     * @param context translation context used for error reporting and resolution
     * @return the resolved expression, or an {@link EError} if the block does not end with
     *         an expression (the error is logged at the last statement's location)
     * @throws InternalCompilerError if the block is empty or contains an unsupported
     *         statement group
     */
    @Override
    public Expression resolve(TranslationContext context) {
        if(statements.isEmpty())
            throw new InternalCompilerError();
        // Positional access on the LinkedList field is O(n); work on an array-backed
        // snapshot so that the backward fold below stays linear.
        ArrayList<Statement> stmts = new ArrayList<Statement>(statements);
        int i = stmts.size()-1;
        Statement last = stmts.get(i);
        if(!(last instanceof GuardStatement)) {
            context.getErrorLog().log(last.location, "Block should end with an expression");
            return new EError(location);
        }

        Expression in = ((GuardStatement)last).value;
        while(--i >= 0) {
            Statement cur = stmts.get(i);
            StatementGroup group = cur.getStatementGroup();
            if(group == null)
                in = cur.toExpression(context, monadic, in);
            else {
                // Extend the group backwards over all directly preceding statements that
                // belong to the same group; the group then spans [i, endId).
                int endId = i+1;
                while(i>0 && stmts.get(i-1).getStatementGroup() == group)
                    --i;
                switch(group) {
                case LetFunction:
                    in = extractLet(stmts, i, endId, in);
                    break;
                case Rule:
                    in = extractRules(stmts, i, endId, in);
                    break;
                case CHR:
                    in = extractCHRRules(context, stmts, i, endId, in);
                    break;
                default:
                    // Never drop statements silently if a new group kind is introduced.
                    throw new InternalCompilerError("Unsupported statement group " + group + ".");
                }
            }
        }
        return in.resolve(context);
    }

    /**
     * Wraps {@code in} in an {@link EPreRuleset} built from {@code stmts[begin, end)}.
     * Every statement in the range is expected to be a {@link RuleStatement}.
     */
    private static Expression extractRules(List<Statement> stmts, int begin, int end, Expression in) {
        return new EPreRuleset(stmts.subList(begin, end).toArray(new RuleStatement[end-begin]), in);
    }
    
    /**
     * Wraps {@code in} in an {@link ECHRRuleset} built from {@code stmts[begin, end)}.
     * The range may contain only {@link CHRStatement}s and {@link ConstraintStatement}s.
     *
     * @throws InternalCompilerError if any other kind of statement is found in the range
     */
    private static Expression extractCHRRules(TranslationContext context, List<Statement> stmts,
            int begin, int end, Expression in) {
        List<Statement> group = stmts.subList(begin, end);
        CHRRuleset ruleset = new CHRRuleset();
        ruleset.location = Locations.combine(group.get(0).location, group.get(group.size()-1).location);
        for(Statement statement : group) {
            if(statement instanceof CHRStatement)
                ruleset.rules.add(CHRTranslation.convertCHRStatement(context, (CHRStatement)statement));
            else if(statement instanceof ConstraintStatement)
                ruleset.constraints.add(CHRTranslation.convertConstraintStatement(context, (ConstraintStatement)statement));
            else
                throw new InternalCompilerError("Invalid CHR statement.");
        }
        return new ECHRRuleset(ruleset, in);
    }

    /**
     * Wraps {@code in} in an {@link EPreLet} built from {@code stmts[begin, end)}.
     * The statements are copied into an independent list (instead of passing a view of a
     * mutable list), with a checked cast so that a non-{@link LetStatement} fails right here.
     */
    private static Expression extractLet(List<Statement> stmts, int begin, int end, Expression in) {
        ArrayList<LetStatement> assignments = new ArrayList<LetStatement>(end-begin);
        for(Statement statement : stmts.subList(begin, end))
            assignments.add((LetStatement)statement);
        return new EPreLet(assignments, in);
    }

    /**
     * Creates a block that contains one {@link GuardStatement} per given expression, in order.
     *
     * @param statements the expressions to wrap
     * @return the new block
     */
    public static Expression create(ArrayList<Expression> statements) {
        EBlock block = new EBlock();
        for(Expression statement : statements)
            block.addStatement(new GuardStatement(statement));
        return block;
    }

    /**
     * Sets {@code loc} on this block and on all its statements, but only if this block
     * does not yet have a location.
     */
    @Override
    public void setLocationDeep(long loc) {
        if(location == Locations.NO_LOCATION) {
            location = loc;
            for(Statement statement : statements)
                statement.setLocationDeep(loc);
        }
    }
    
    @Override
    public Expression accept(ExpressionTransformer transformer) {
        return transformer.transform(this);
    }

    /**
     * Returns the syntactic function arity of the final expression of the block.
     * Returns 0 if the block is monadic, empty, or does not end with a {@link GuardStatement}.
     */
    @Override
    public int getSyntacticFunctionArity() {
        if(monadic || statements.isEmpty())
            return 0;
        Statement lastStatement = statements.getLast();
        if(!(lastStatement instanceof GuardStatement))
            return 0;
        return ((GuardStatement)lastStatement).value.getSyntacticFunctionArity();
    }
    
    @Override
    public void accept(ExpressionVisitor visitor) {
        visitor.visit(this);
    }
}
