package org.simantics.scl.compiler.elaboration.expressions;

import java.util.ArrayList;
import java.util.Set;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.common.precedence.Precedence;
import org.simantics.scl.compiler.compilation.CompilationContext;
import org.simantics.scl.compiler.constants.NoRepConstant;
import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
import org.simantics.scl.compiler.elaboration.contexts.SimplificationContext;
import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.elaboration.errors.NotPatternException;
import org.simantics.scl.compiler.elaboration.expressions.lhstype.LhsType;
import org.simantics.scl.compiler.elaboration.expressions.lhstype.PatternMatchingLhs;
import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;
import org.simantics.scl.compiler.elaboration.expressions.visitors.CollectEffectsVisitor;
import org.simantics.scl.compiler.elaboration.expressions.visitors.CollectFreeVariablesVisitor;
import org.simantics.scl.compiler.elaboration.expressions.visitors.CollectRefsVisitor;
import org.simantics.scl.compiler.elaboration.expressions.visitors.CollectVarsVisitor;
import org.simantics.scl.compiler.elaboration.expressions.visitors.ForVariablesUsesVisitor;
import org.simantics.scl.compiler.elaboration.expressions.visitors.StandardExpressionVisitor;
import org.simantics.scl.compiler.elaboration.query.QAtom;
import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
import org.simantics.scl.compiler.internal.codegen.references.IVal;
import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter;
import org.simantics.scl.compiler.internal.elaboration.decomposed.DecomposedExpression;
import org.simantics.scl.compiler.internal.interpreted.IExpression;
import org.simantics.scl.compiler.internal.parsing.Symbol;
import org.simantics.scl.compiler.top.ExpressionInterpretationContext;
import org.simantics.scl.compiler.types.TForAll;
import org.simantics.scl.compiler.types.TFun;
import org.simantics.scl.compiler.types.TPred;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;
import org.simantics.scl.compiler.types.kinds.Kinds;
import org.simantics.scl.compiler.types.util.Typed;

import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.TIntHashSet;

/**
 * Base class of all expressions in the SCL elaboration tree.
 *
 * <p>An expression carries a lazily computed {@link Type} (see {@link #getType()}) and
 * provides default implementations of the typing, resolution, replacement and code
 * generation hooks. Concrete subclasses override these hooks. Traversal is done with
 * {@link #accept(ExpressionVisitor)} and {@link #accept(ExpressionTransformer)}.
 */
public abstract class Expression extends Symbol implements Typed {
    public static final Expression[] EMPTY_ARRAY = new Expression[0];

    /** Cached type of this expression; {@code null} until set or computed by {@link #updateType()}. */
    private transient Type type;

    public Expression() {
    }

    /**
     * @param loc source location of this expression
     */
    public Expression(long loc) {
        this.location = loc;
    }

    /**
     * Returns the type of this expression. If no type has been set yet, it is computed by
     * calling {@link #updateType()}, which is expected to store the result with
     * {@link #setType(Type)}.
     *
     * @throws InternalCompilerError if {@code updateType} throws a {@link MatchException}
     *         or leaves the type unset
     */
    @Override
    public Type getType() {
        if(type == null) {
            try {
                updateType();
            } catch (MatchException e) {
                throw new InternalCompilerError(location, e);
            }
            if(type == null)
                throw new InternalCompilerError(location, getClass().getSimpleName() +
                        ".updateType couldn't compute its type.");
        }
        return type;
    }

    /**
     * Sets the type of this expression.
     *
     * @throws NullPointerException if {@code type} is null
     */
    public void setType(Type type) {
        if(type == null)
            throw new NullPointerException();
        this.type = type;
    }

    /**
     * Infers the type of the expression without any context. Adds type
     * applications and lambdas if needed.
     *
     * <p>The default implementation checks the expression against a fresh meta variable of
     * kind {@code STAR} via {@link #checkBasicType}. The default {@code checkBasicType}
     * calls back into this method, so every concrete subclass must override at least one
     * of the two. Otherwise the calls recurse without bound.
     */
    public Expression inferType(TypingContext context) {
        return checkBasicType(context, Types.metaVar(Kinds.STAR));
    }

    /**
     * Checks the expression against {@code requiredType}, without the forall/constraint/
     * punit handling done by {@link #checkType}. The default implementation infers the type
     * and lets the context subsume it to the required type. See {@link #inferType} for the
     * mutual-recursion caveat.
     */
    public Expression checkBasicType(TypingContext context, Type requiredType) {
        return context.subsume(inferType(context), requiredType);
    }

    /**
     * If the type of this expression is a function whose domain is {@code PUNIT}, returns
     * an application of this expression to the {@code PUNIT} literal. The application
     * records the function's effect, and that effect is declared to {@code context}.
     * Otherwise returns {@code this} unchanged.
     */
    protected Expression applyPUnit(TypingContext context) {
        Type type = Types.canonical(getType());
        if(type instanceof TFun) {
            TFun fun = (TFun)type;
            if(fun.getCanonicalDomain() == Types.PUNIT) {
                EApply result = new EApply(location, this, new ELiteral(NoRepConstant.PUNIT));
                result.effect = fun.getCanonicalEffect();
                context.declareEffect(this.location, result.effect);
                return result;
            }
        }
        return this;
    }

    /**
     * Infers the type of an expression whose value is not used. If the inferred type is not
     * {@code UNIT}, the expression is wrapped in an {@link ESimpleLet} with a {@code null}
     * variable and a {@code PUNIT} literal body. Presumably this discards the value; confirm
     * against {@code ESimpleLet}.
     */
    public Expression checkIgnoredType(TypingContext context) {
        Expression expression = inferType(context);
        if(Types.canonical(expression.getType()) != Types.UNIT)
            expression = new ESimpleLet(location, null, expression, new ELiteral(NoRepConstant.PUNIT));
        return expression;
    }

    /**
     * Checks the type of the expression against the given type. Adds type
     * applications and lambdas if needed.
     *
     * <p>Outside of patterns the required type is canonicalized first, and then:
     * <ul>
     *   <li>a {@code forall} type is instantiated with a fresh type variable and the checked
     *       result is wrapped in an {@link ELambdaType};</li>
     *   <li>a run of leading constraint ({@link TPred}) arguments becomes constraint
     *       variables. These are pushed as a constraint frame while the rest is checked, and
     *       the result is wrapped in one {@link ESimpleLambda} per constraint;</li>
     *   <li>a {@code PUNIT} argument is handled by checking the range under the function's
     *       effect as an upper bound, and the result is wrapped in an {@link ESimpleLambda}
     *       over a fresh {@code punit} variable.</li>
     * </ul>
     * In every other case, and always inside patterns, this delegates to
     * {@link #checkBasicType}.
     */
    public final Expression checkType(TypingContext context, Type requiredType) {
        if(!context.isInPattern()) {
            requiredType = Types.canonical(requiredType);
            if(requiredType instanceof TForAll) {
                TForAll forAll = (TForAll)requiredType;
                TVar var = forAll.var;
                TVar newVar = Types.var(var.getKind());
                requiredType = Types.canonical(forAll.type).replace(var, newVar);
                return new ELambdaType(new TVar[] {newVar}, checkType(context, requiredType));
            }
            // Every branch of this loop returns or breaks, so the body runs at most once.
            while(requiredType instanceof TFun) {
                TFun fun = (TFun)requiredType;
                if(fun.domain instanceof TPred) { // No need to canonicalize
                    // Collect all consecutive constraint arguments into one frame.
                    ArrayList<Variable> constraints = new ArrayList<Variable>(2);
                    while(true) {
                        constraints.add(new Variable("constraint", fun.domain));
                        requiredType = Types.canonical(fun.range);
                        if(!(requiredType instanceof TFun))
                            break;
                        fun = (TFun)requiredType;
                        if(!(fun.domain instanceof TPred))
                            break;
                    }
                    context.pushConstraintFrame(constraints.toArray(new Variable[constraints.size()]));
                    Expression expression = checkType(context, requiredType);
                    context.popConstraintFrame();
                    // Wrap innermost-first so the first constraint becomes the outermost lambda.
                    for(int i=constraints.size()-1;i>=0;--i)
                        expression = new ESimpleLambda(constraints.get(i), expression);
                    return expression;
                }
                else if(fun.domain == Types.PUNIT) {
                    context.pushEffectUpperBound(location, fun.effect);
                    Expression expr = checkType(context, fun.range);
                    context.popEffectUpperBound();

                    // Wrap
                    Variable var = new Variable("punit", Types.PUNIT);
                    return new ESimpleLambda(location, var, fun.effect, expr);
                }
                else
                    break;
            }
        }
        return checkBasicType(context, requiredType);
    }

    /**
     * Adds to {@code refs} the ids, taken from {@code allRefs}, of the objects referenced by
     * this expression, as determined by {@link CollectRefsVisitor}.
     */
    public final void collectRefs(TObjectIntHashMap<Object> allRefs, TIntHashSet refs) {
        accept(new CollectRefsVisitor(allRefs, refs));
    }

    /**
     * Adds to {@code vars} the ids, taken from {@code allVars}, of the variables found by
     * {@link CollectVarsVisitor}.
     */
    public final void collectVars(TObjectIntHashMap<Variable> allVars, TIntHashSet vars) {
        accept(new CollectVarsVisitor(allVars, vars));
    }

    /** Calls {@code procedure} for each variable use found by {@link ForVariablesUsesVisitor}. */
    public final void forVariableUses(VariableProcedure procedure) {
        accept(new ForVariablesUsesVisitor(procedure));
    }

    /** Returns this expression unchanged; subclasses that contain pattern matching override it. */
    public Expression decomposeMatching() {
        return this;
    }

    /** Renders this expression with {@link ExpressionToStringVisitor}. */
    @Override
    public String toString() {
        StringBuilder b = new StringBuilder();
        ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);
        accept(visitor);
        return b.toString();
    }

    /**
     * Computes the type of this expression and stores it with {@link #setType(Type)}.
     * Called lazily by {@link #getType()}.
     */
    protected abstract void updateType() throws MatchException;

    /** Signals a type mismatch found while validating the types of an expression tree. */
    public static class TypeValidationException extends Exception {
        private static final long serialVersionUID = 3181298127162041248L;

        long loc;

        /** @param loc location of the offending expression */
        public TypeValidationException(long loc) {
            this.loc = loc;
        }

        /** Returns the location of the offending expression. */
        public long getLoc() {
            return loc;
        }

        /**
         * @param loc location of the offending expression
         * @param cause underlying failure
         */
        public TypeValidationException(long loc, Throwable cause) {
            super(cause);
            this.loc = loc;
        }
    }

    /**
     * Throws a {@link TypeValidationException} at {@code loc} unless {@code a} and {@code b}
     * are equal according to {@link Types#equals}.
     */
    public static void assertEquals(long loc, Type a, Type b) throws TypeValidationException {
        if(!Types.equals(a, b))
            throw new TypeValidationException(loc);
    }

    /** Emits code for this expression with {@code w} and returns the resulting value. */
    public abstract IVal toVal(CompilationContext context, CodeWriter w);

    /**
     * Abstracts this expression over the given type variables. Returns {@code this} when
     * {@code vars} is empty, otherwise an {@link ELambdaType}.
     */
    public Expression closure(TVar ... vars) {
        if(vars.length == 0)
            return this;
        return new ELambdaType(vars, this);
    }

    /**
     * Simplifies this expression. The default implementation does not support
     * simplification and always throws.
     *
     * @throws InternalCompilerError always, unless overridden
     */
    public Expression simplify(SimplificationContext context) {
        // Previously this also dumped a banner and the expression to System.out; that
        // leftover debug output has been removed. The error message is unchanged.
        throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support simplify method.");
    }

    /** Resolves names in this expression using {@code context}. */
    public abstract Expression resolve(TranslationContext context);

    /**
     * Returns head of the pattern.
     *
     * @throws NotPatternException by default; pattern-like subclasses override this
     */
    public EVar getPatternHead() throws NotPatternException {
        throw new NotPatternException(this);
    }

    /** @throws NotPatternException by default; subclasses usable as a left-hand side override this */
    public LhsType getLhsType() throws NotPatternException {
        throw new NotPatternException(this);
    }

    /** @throws NotPatternException by default; pattern subclasses override this */
    protected void collectVariableNames(PatternMatchingLhs lhsType) throws NotPatternException {
        throw new NotPatternException(this);
    }

    /**
     * Collects the parameters of this expression into {@code parameters}.
     *
     * @throws InternalCompilerError by default; unsupported by this class
     */
    public void getParameters(TranslationContext translationContext,
            ArrayList<Expression> parameters) {
        throw new InternalCompilerError(location, "Class " + getClass().getSimpleName() + " does not support getParameters.");
    }

    /**
     * Resolves this expression as a pattern. The default implementation logs
     * "Pattern was expected here." at this location and returns an {@link EError}.
     */
    public Expression resolveAsPattern(TranslationContext context) {
        context.getErrorLog().log(location, "Pattern was expected here.");
        return new EError();
    }

    /**
     * Type-checks this expression as a pattern. Pattern mode is switched on for the check
     * and switched off again afterwards, even if the check throws.
     *
     * @throws InternalCompilerError if the context is already in pattern mode
     */
    public Expression checkTypeAsPattern(TypingContext context, Type requiredType) {
        if(context.isInPattern())
            throw new InternalCompilerError(location, "Already in a pattern.");
        context.setInPattern(true);
        try {
            return checkType(context, requiredType);
        } finally {
            context.setInPattern(false);
        }
    }

    /**
     * Used during simplification and in toIExpression.
     *
     * @return the free variables of this expression, as collected by
     *         {@link CollectFreeVariablesVisitor}
     */
    public Set<Variable> getFreeVariables() {
        CollectFreeVariablesVisitor visitor = new CollectFreeVariablesVisitor();
        accept(visitor);
        return visitor.getFreeVariables();
    }

    /**
     * Concatenates two expression arrays. When one of the arrays is empty, the other one is
     * returned as is (not copied), so callers must not mutate the result in place.
     */
    public static Expression[] concat(Expression[] a, Expression[] b) {
        if(a.length == 0)
            return b;
        if(b.length == 0)
            return a;
        Expression[] result = new Expression[a.length + b.length];
        System.arraycopy(a, 0, result, 0, a.length);
        System.arraycopy(b, 0, result, a.length, b.length);
        return result;
    }

    /**
     * Returns a copy of this expression transformed by {@code context}.
     *
     * @throws InternalCompilerError by default; unsupported by this class
     */
    public Expression replace(ReplaceContext context) {
        throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support replace.");
    }

    /** Applies {@link #replace(ReplaceContext)} to each element, returning a new array. */
    public static Expression[] replace(ReplaceContext context, Expression[] expressions) {
        Expression[] result = new Expression[expressions.length];
        for(int i=0;i<expressions.length;++i)
            result[i] = expressions[i].replace(context);
        return result;
    }

    /** Copies this expression via {@link #replace} with a context that has no typing context. */
    public Expression copy() {
        return replace(new ReplaceContext(null));
    }

    /** Copies this expression via {@link #replace} with the given typing context. */
    public Expression copy(TypingContext typingContext) {
        return replace(new ReplaceContext(typingContext));
    }

    /** Sets {@code loc} as the location of this expression and its subexpressions. */
    public abstract void setLocationDeep(long loc);

    /**
     * Like {@link #replace(ReplaceContext)}, but with {@code context.inPattern} set for the
     * duration of the call. The flag is reset to {@code false} afterwards, even if
     * {@code replace} throws.
     */
    public Expression replaceInPattern(ReplaceContext context) {
        context.inPattern = true;
        try {
            return replace(context);
        } finally {
            context.inPattern = false;
        }
    }

    /** @throws NotPatternException by default; function-definition patterns override this */
    public int getFunctionDefinitionPatternArity() throws NotPatternException {
        throw new NotPatternException(this);
    }

    /**
     * Compiles this expression as a separate function. The expression is decomposed into
     * type parameters, parameters, effect, return type and body. A new function is created
     * with {@code w}, the decomposed parameters are bound to its parameter values, and the
     * body's value is returned from it.
     *
     * @return the value referring to the created function
     */
    public IVal lambdaToVal(CompilationContext context, CodeWriter w) {
        DecomposedExpression decomposed = DecomposedExpression.decompose(context.errorLog, this);
        CodeWriter newW = w.createFunction(decomposed.typeParameters, decomposed.effect, decomposed.returnType, decomposed.parameterTypes);
        IVal[] parameters = newW.getParameters();
        IVal functionVal = newW.getFunction().getTarget();
        for(int i=0;i<parameters.length;++i)
            decomposed.parameters[i].setVal(parameters[i]);
        newW.return_(decomposed.body.toVal(context, newW));
        return functionVal;
    }

    /**
     * Converts this expression to the interpreted representation.
     *
     * @throws UnsupportedOperationException by default
     */
    public IExpression toIExpression(ExpressionInterpretationContext context) {
        throw new UnsupportedOperationException();
    }

    /** Applies {@link #toIExpression} to each element, returning a new array. */
    public static IExpression[] toIExpressions(ExpressionInterpretationContext target, Expression[] expressions) {
        IExpression[] result = new IExpression[expressions.length];
        for(int i=0;i<expressions.length;++i)
            result[i] = expressions[i].toIExpression(target);
        return result;
    }

    /** Returns an {@link EApplyType} applying this expression to {@code type}. */
    public Expression applyType(Type type) {
        return new EApplyType(location, this, type);
    }

    /** Returns {@code true} by default; subclasses may override. */
    public boolean isEffectful() {
        return true;
    }

    /** Returns {@code false} by default; subclasses may override. */
    public boolean isFunctionPattern() {
        return false;
    }

    /** Returns {@code false} by default; subclasses may override. */
    public boolean isConstructorApplication() {
        return false;
    }

    /** Returns the combined effect of this expression as collected by {@link CollectEffectsVisitor}. */
    public Type getEffect() {
        CollectEffectsVisitor visitor = new CollectEffectsVisitor();
        accept(visitor);
        return visitor.getCombinedEffect();
    }

    /** Dispatches to the {@code visit} overload of {@code visitor} for this expression's class. */
    public abstract void accept(ExpressionVisitor visitor);

    /**
     * For every {@link QAtom} in this expression, looks up the atom's relation in
     * {@code allRefs} and adds the resulting id to {@code refs} if it is non-negative.
     */
    public void collectRelationRefs(
            final TObjectIntHashMap<SCLRelation> allRefs, final TIntHashSet refs) {
        accept(new StandardExpressionVisitor() {
            @Override
            public void visit(QAtom query) {
                // NOTE(review): TObjectIntHashMap.get returns the map's no-entry value for
                // missing keys. The "id >= 0" test assumes the caller created allRefs with a
                // negative no-entry value. Confirm at the call sites.
                int id = allRefs.get(query.relation);
                if(id >= 0)
                    refs.add(id);
            }
        });
    }

    /** Returns {@code false} by default; subclasses may override. */
    public boolean isFunctionDefinitionLhs() {
        return false;
    }

    /** Returns {@link Precedence#DEFAULT} by default; used when printing. */
    public Precedence getPrecedence() {
        return Precedence.DEFAULT;
    }

    /** Returns {@code false} by default; subclasses may override. */
    public boolean isPattern(int arity) {
        return false;
    }

    /** Applies {@code transformer} to this expression and returns the transformed result. */
    public abstract Expression accept(ExpressionTransformer transformer);

    // TODO implement for all expressions
    /** Structural equality; the default implementation always returns {@code false}. */
    public boolean equalsExpression(Expression expression) {
        return false;
    }

    /**
     * This method returns a lower bound for the function arity of the value this expression defines.
     * The lower bound is calculated purely looking the syntax of the expression, not the
     * types of the constants and variables the expression refers to.
     */
    public int getSyntacticFunctionArity() {
        return 0;
    }
}
