package org.simantics.scl.compiler.top;

import java.io.StringReader;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.ListIterator;
import java.util.Map;

import org.simantics.scl.compiler.common.names.Name;
import org.simantics.scl.compiler.compilation.CodeGeneration;
import org.simantics.scl.compiler.compilation.CompilationContext;
import org.simantics.scl.compiler.constants.JavaStaticMethod;
import org.simantics.scl.compiler.constants.SCLConstant;
import org.simantics.scl.compiler.elaboration.contexts.SimplificationContext;
import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.elaboration.errors.NotPatternException;
import org.simantics.scl.compiler.elaboration.expressions.EApply;
import org.simantics.scl.compiler.elaboration.expressions.EBlock;
import org.simantics.scl.compiler.elaboration.expressions.EConstant;
import org.simantics.scl.compiler.elaboration.expressions.EExternalConstant;
import org.simantics.scl.compiler.elaboration.expressions.EVar;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.expressions.block.GuardStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.LetStatement;
import org.simantics.scl.compiler.elaboration.expressions.block.Statement;
import org.simantics.scl.compiler.elaboration.java.Builtins;
import org.simantics.scl.compiler.environment.Environment;
import org.simantics.scl.compiler.environment.LocalEnvironment;
import org.simantics.scl.compiler.errors.CompilationError;
import org.simantics.scl.compiler.errors.ErrorLog;
import org.simantics.scl.compiler.internal.codegen.references.IVal;
import org.simantics.scl.compiler.internal.codegen.ssa.SSAModule;
import org.simantics.scl.compiler.internal.codegen.types.DummyJavaReferenceValidator;
import org.simantics.scl.compiler.internal.codegen.types.JavaTypeTranslator;
import org.simantics.scl.compiler.internal.codegen.utils.CodeBuildingException;
import org.simantics.scl.compiler.internal.codegen.utils.JavaNamingPolicy;
import org.simantics.scl.compiler.internal.codegen.utils.ModuleBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.TransientClassBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.ValueFromMethod;
import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter;
import org.simantics.scl.compiler.internal.codegen.writer.ExternalConstant;
import org.simantics.scl.compiler.internal.codegen.writer.ModuleWriter;
import org.simantics.scl.compiler.internal.elaboration.decomposed.DecomposedExpression;
import org.simantics.scl.compiler.internal.interpreted.IExpression;
import org.simantics.scl.compiler.internal.parsing.exceptions.SCLSyntaxErrorException;
import org.simantics.scl.compiler.internal.parsing.parser.SCLBlockParser;
import org.simantics.scl.compiler.internal.parsing.parser.SCLParserImpl;
import org.simantics.scl.compiler.internal.parsing.parser.SCLParserOptions;
import org.simantics.scl.compiler.internal.parsing.utils.LineLocators;
import org.simantics.scl.compiler.runtime.MutableClassLoader;
import org.simantics.scl.compiler.runtime.RuntimeEnvironment;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.kinds.Kinds;
import org.simantics.scl.compiler.types.util.Polarity;
import org.simantics.scl.compiler.types.util.ProcedureType;
import org.simantics.scl.runtime.function.FunctionImpl1;
import org.simantics.scl.runtime.tuple.Tuple0;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;

import gnu.trove.set.hash.THashSet;

/**
 * Parses, type checks and evaluates an SCL expression in a given {@link RuntimeEnvironment}.
 *
 * <p>An evaluator is configured with the fluent setters and then run with {@link #eval()}
 * or {@link #validate()}. Evaluation overwrites the stored expression and the expected
 * type and effect, so an instance is meant for a single evaluation
 * (NOTE(review): confirm that callers never reuse an evaluator).
 */
public class ExpressionEvaluator {

    private static final Logger LOGGER = LoggerFactory.getLogger(ExpressionEvaluator.class);

    /** When true, logs the text of every expression that is evaluated by interpretation. */
    public static final boolean TRACE_INTERPRETATION_VS_COMPILATION = false;
    /** Name of the generated static method that computes the value of the expression. */
    private static final String COMPUTATION_METHOD_NAME = "main";
    /** Upper bound on SSA simplification rounds within a single optimization phase. */
    private static final int MAX_SIMPLIFICATION_ROUNDS = 4;
    
    private final RuntimeEnvironment runtimeEnvironment;
    /** Source text to parse, or null when the evaluator was given a ready-made expression. */
    private final String expressionText;
    /** The expression being processed; replaced by the result of each compilation phase. */
    private Expression expression;
    /** Type of the expression after constraint solving; see {@link #getType()}. */
    private Type expressionType;
    
    private Type expectedEffect;
    private boolean decorateExpression;
    private Type expectedType;
    private LocalEnvironment localEnvironment;
    /** Receives the values of top-level let-bound variables of a block expression, if set. */
    private LocalStorage localStorage;
    private boolean interpretIfPossible = true;
    private ExpressionParseMode parseMode = ExpressionParseMode.EXPRESSION;
    private boolean validateOnly;
    
    /**
     * Creates an evaluator that parses {@code expressionText} when evaluated.
     *
     * @throws NullPointerException if either argument is null
     */
    public ExpressionEvaluator(RuntimeEnvironment runtimeEnvironment,
            String expressionText) {
        if(runtimeEnvironment == null)
            throw new NullPointerException("runtimeEnvironment");
        if(expressionText == null)
            throw new NullPointerException("expressionText");
        this.runtimeEnvironment = runtimeEnvironment;
        this.expressionText = expressionText;
    }
    
    /**
     * Creates an evaluator for an already parsed expression.
     *
     * @param localStorage storage for the block's let-bound variables; may be null
     * @throws NullPointerException if {@code runtimeEnvironment} or {@code expression} is null
     */
    public ExpressionEvaluator(RuntimeEnvironment runtimeEnvironment,
            LocalStorage localStorage, Expression expression) {
        if(runtimeEnvironment == null)
            throw new NullPointerException("runtimeEnvironment");
        if(expression == null)
            throw new NullPointerException("expression");
        this.runtimeEnvironment = runtimeEnvironment;
        this.localStorage = localStorage;
        this.expressionText = null;
        this.expression = expression;
    }
    
    /** Sets the upper bound for the effect of the expression; defaults to a fresh meta variable. */
    public ExpressionEvaluator expectedEffect(Type expectedEffect) {
        this.expectedEffect = expectedEffect;
        return this;
    }
    
    /**
     * When set and the expected effect is not pure, the type-checked expression is passed
     * through a {@code ToplevelEffectDecorator}.
     */
    public ExpressionEvaluator decorateExpression(boolean decorateExpression) {
        this.decorateExpression = decorateExpression;
        return this;
    }
    
    /** Sets the expected type of the expression; defaults to a fresh meta variable. */
    public ExpressionEvaluator expectedType(Type expectedType) {
        this.expectedType = expectedType;
        return this;
    }
    
    /** When set, {@link #eval()} stops after type checking and returns null. */
    public ExpressionEvaluator validateOnly(boolean validateOnly) {
        this.validateOnly = validateOnly;
        return this;
    }
    
    /**
     * Sets a local environment that can arbitrarily modify the resolving of the expression.
     */
    public ExpressionEvaluator localEnvironment(LocalEnvironment localEnvironment) {
        this.localEnvironment = localEnvironment;
        return this;
    }
    
    /**
     * Evaluates the expression by interpretation instead of compilation to bytecode
     * if the expression does not contain language constructs that interpretation does
     * not support.
     */
    public ExpressionEvaluator interpretIfPossible(boolean interpretIfPossible) {
        this.interpretIfPossible = interpretIfPossible;
        return this;
    }
    
    /**
     * Assumes that top level of the expression is similar to the content
     * of a do-block.
     */
    public ExpressionEvaluator parseAsBlock(boolean parseAsBlock) {
        this.parseMode = parseAsBlock ? ExpressionParseMode.BLOCK : ExpressionParseMode.EXPRESSION;
        return this;
    }
    
    /**
     * Sets how the expression text is parsed. Only relevant for evaluators created from text.
     */
    public ExpressionEvaluator parseModel(ExpressionParseMode parseMode) {
        this.parseMode = parseMode;
        return this;
    }
    
    /** Replaces an unset expected effect/type with fresh meta variables. */
    private void fillDefaults() {
        if(expectedEffect == null)
            expectedEffect = Types.metaVar(Kinds.EFFECT);
        if(expectedType == null)
            expectedType = Types.metaVar(Kinds.STAR);
    }
    
    /**
     * SCL function that stores its argument under {@code name} in a {@link LocalStorage}
     * and returns the unit tuple.
     */
    private static class StoreFunction extends FunctionImpl1<Object, Object> {
        final LocalStorage storage;
        final String name;
        final Type type;
        public StoreFunction(LocalStorage storage, String name, Type type) {
            this.storage = storage;
            this.name = name;
            this.type = type;
        }
        @Override
        public Object apply(Object value) {
            // Converted at call time, i.e. after type checking has processed the meta variables.
            Type storedType = Types.closure(type.convertMetaVarsToVars());
            storage.store(name, value, storedType);
            return Tuple0.INSTANCE;
        }
        
        @Override
        public String toString() {
            return "store_" + name;
        }
    }
    
    /**
     * Type checks the expression without generating or running code.
     *
     * <p>This permanently switches the evaluator into validate-only mode.
     *
     * @return the compilation errors, or an empty array if none were reported
     */
    public CompilationError[] validate() {
        try {
            validateOnly = true;
            eval();
            return CompilationError.EMPTY_ARRAY;
        } catch(SCLExpressionCompilationException e) {
            return e.getErrors();
        }
    }

    /**
     * Parses (when created from text), resolves, type checks and evaluates the expression.
     *
     * <p>If interpretation is enabled and the simplified expression can be converted to an
     * interpretable form, it is interpreted. Otherwise it is compiled to bytecode in a fresh
     * package of the runtime environment's class loader, and the value is obtained from the
     * generated static {@code main} method.
     *
     * @return the value of the expression, or null in validate-only mode
     * @throws SCLExpressionCompilationException if any compilation phase reports errors
     */
    public Object eval() throws SCLExpressionCompilationException {
        fillDefaults();
        
        final CompilationContext compilationContext = new CompilationContext();
        final ErrorLog errorLog = compilationContext.errorLog;
        final Environment environment = runtimeEnvironment.getEnvironment();
        compilationContext.environment = environment;
        
        // Parse expression
        if(expressionText != null) {
            compilationContext.lineLocator = LineLocators.createLineLocator(expressionText);
            parseExpressionText(errorLog);
        }
        else
            compilationContext.lineLocator = LineLocators.DUMMY_LOCATOR;
        
        // Terminate a block and store its local variables
        ArrayList<Type> lvTypes = completeBlock();
        
        // Elaboration
        {
            TranslationContext context = new TranslationContext(compilationContext, localEnvironment, "expression");
            expression = expression.resolve(context);
            checkNoErrors(errorLog);
        }
        
        // Apply local environment
        if(localEnvironment != null) {
            expression = localEnvironment.preDecorateExpression(expression);
            ProcedureType procedureType = localEnvironment.decorateExpectedType(expectedType, expectedEffect);
            expectedType = procedureType.type;
            expectedEffect = procedureType.effect;
        }
        
        // Type checking
        {
            TypingContext context = new TypingContext(compilationContext);

            context.pushEffectUpperBound(expression.location, expectedEffect);
            expression = expression.checkType(context, expectedType);
            context.popEffectUpperBound();

            for(Type lvType : lvTypes)
                lvType.addPolarity(Polarity.POSITIVE);
            
            expectedType.addPolarity(Polarity.POSITIVE);
            context.solveSubsumptions(expression.location);
            checkNoErrors(errorLog);
            if(decorateExpression && Types.canonical(expectedEffect) != Types.NO_EFFECTS) {
                ToplevelEffectDecorator decorator =
                        new ToplevelEffectDecorator(errorLog, environment);
                expression = expression.accept(decorator);
            }
            expression = context.solveConstraints(environment, expression);
            expressionType = expression.getType();
            checkNoErrors(errorLog);

            if(localEnvironment != null)
                expression = localEnvironment.postDecorateExpression(expression);
            
            if(validateOnly)
                return null;

            Type type = expression.getType();
            type = type.convertMetaVarsToVars();
            
            // The result is deliberately discarded; the call is kept for its effect on the
            // meta variables of the stored local variables. NOTE(review): confirm this side effect.
            for(Type lvType : lvTypes)
                lvType.convertMetaVarsToVars();
            
            ArrayList<TVar> varsList = Types.freeVars(type);
            expression = expression.closure(varsList.toArray(new TVar[varsList.size()]));
        }
        
        // Initialize code generation
        MutableClassLoader classLoader = runtimeEnvironment.getMutableClassLoader();
        String moduleName = classLoader.getFreshPackageName();
        JavaTypeTranslator javaTypeTranslator = new JavaTypeTranslator(environment);
        compilationContext.javaTypeTranslator = javaTypeTranslator;
        JavaNamingPolicy namingPolicy = new JavaNamingPolicy(moduleName);
        compilationContext.namingPolicy = namingPolicy;

        ModuleBuilder moduleBuilder = new ModuleBuilder(namingPolicy, javaTypeTranslator);
        
        // Simplify
        SimplificationContext context = 
                new SimplificationContext(compilationContext, DummyJavaReferenceValidator.INSTANCE);
        expression = expression.simplify(context);
        checkNoErrors(errorLog);
        
        if(SCLCompilerConfiguration.SHOW_EXPRESSION_BEFORE_EVALUATION)
            LOGGER.info("{}", expression);
        
        if(interpretIfPossible) {
            // Try to convert to an interpretable form. Only the conversion is guarded:
            // an UnsupportedOperationException raised while *executing* the expression is a
            // genuine runtime error and must not trigger a second (compiled) execution.
            ExpressionInterpretationContext interpretationContext = null;
            IExpression iexp = null;
            try {
                interpretationContext = new ExpressionInterpretationContext(runtimeEnvironment, 
                        new TransientClassBuilder(classLoader, javaTypeTranslator));
                iexp = expression.toIExpression(interpretationContext);
            } catch(UnsupportedOperationException e) {
                // This is normal when expression cannot be interpreted. We compile it instead.
            }
            if(iexp != null) {
                if(TRACE_INTERPRETATION_VS_COMPILATION)
                    LOGGER.info("INTERPRETED {}", expressionText);
                if(SCLCompilerConfiguration.SHOW_INTERPRETED_EXPRESSION)
                    LOGGER.info("INTERPRETED AS: {}", iexp);
                return iexp.execute(new Object[interpretationContext.getMaxVariableId()]);
            }
        }
        
        // Convert to SSA
        ModuleWriter mw = new ModuleWriter(namingPolicy.getModuleClassName(), compilationContext.lineLocator);
        DecomposedExpression decomposed = 
                DecomposedExpression.decompose(errorLog, expression);

        SCLConstant constant = new SCLConstant(
                Name.create(moduleName, COMPUTATION_METHOD_NAME),
                expression.getType());
        constant.setBase(new JavaStaticMethod(
                moduleName, COMPUTATION_METHOD_NAME,
                decomposed.effect,
                decomposed.typeParameters,
                decomposed.returnType, 
                decomposed.parameterTypes));
        try {
            CodeWriter w = mw.createFunction(constant,
                    decomposed.typeParameters,
                    decomposed.effect,
                    decomposed.returnType, 
                    decomposed.parameterTypes);
            constant.setDefinition(w.getFunction());
            IVal[] parameterVals = w.getParameters();
            for(int i=0;i<decomposed.parameters.length;++i)
                decomposed.parameters[i].setVal(parameterVals[i]);
            w.return_(decomposed.body.location, decomposed.body.toVal(compilationContext, w));
        } catch(RuntimeException e) {
            errorLog.setExceptionPosition(expression.location);
            errorLog.log(e);
            throw compilationFailure(errorLog);
        }

        SSAModule ssaModule = mw.getModule();
        if(SCLCompilerConfiguration.SHOW_SSA_BEFORE_OPTIMIZATION) {
            LOGGER.info("=== SSA before optimization ==================================");
            LOGGER.info("{}", ssaModule);
        }
        if(SCLCompilerConfiguration.DEBUG)
            ssaModule.validate();

        ExternalConstant[] externalConstants = mw.getExternalConstants();
        
        // Optimize SSA: simplify each phase until nothing changes or the round limit is hit
        for(int phase=0;phase<CodeGeneration.OPTIMIZATION_PHASES;++phase) {
            for(int round=0;round<MAX_SIMPLIFICATION_ROUNDS && ssaModule.simplify(environment, phase);++round) {
                // simplify(...) does the work in the loop condition
            }
        }
        if(SCLCompilerConfiguration.SHOW_SSA_BEFORE_LAMBDA_LIFTING) {
            LOGGER.info("=== SSA before lambda lifting ==================================");
            LOGGER.info("{}", ssaModule);
        }
        ssaModule.lambdaLift(errorLog);
        ssaModule.markGenerateOnFly();
        
        // Generate code
        if(SCLCompilerConfiguration.SHOW_FINAL_SSA)
            LOGGER.info("{}", ssaModule);
        try {
            ssaModule.generateCode(moduleBuilder);
        } catch (CodeBuildingException e) {
            errorLog.log(e);
            throw compilationFailure(errorLog);
        }
        Map<String, byte[]> classes = moduleBuilder.getClasses();
        ssaModule.cleanup();
        
        // Load generated code and execute
        try {
            classLoader.addClasses(classes);
            Class<?> clazz = classLoader.loadClass(MutableClassLoader.SCL_PACKAGE_PREFIX + moduleName);
            // External constants are passed to the generated class through its static fields
            for(ExternalConstant externalConstant : externalConstants)
                clazz.getField(externalConstant.fieldName).set(null, externalConstant.value);
            for(Method method : clazz.getMethods()) {
                if(method.getName().equals(COMPUTATION_METHOD_NAME))
                    return ValueFromMethod.getValueFromStaticMethod(method);
            }
            errorLog.log("Internal compiler error: didn't find method " +
                    COMPUTATION_METHOD_NAME + " from generated byte code.");
            throw compilationFailure(errorLog);
        } catch(ReflectiveOperationException e) {
            errorLog.log(e);
            throw compilationFailure(errorLog);
        }
    }

    /**
     * Parses {@link #expressionText} according to {@link #parseMode} and stores the result
     * in {@link #expression}.
     *
     * @throws SCLExpressionCompilationException on a syntax error or any other parser failure
     */
    private void parseExpressionText(ErrorLog errorLog) throws SCLExpressionCompilationException {
        try {
            switch(parseMode) {
            case BLOCK: {
                SCLBlockParser parser = new SCLBlockParser(new StringReader(expressionText));
                parser.parseCommands();
                expression = parser.block;
            } break;
            case EXPRESSION: {
                SCLParserImpl parser = new SCLParserImpl(new StringReader(expressionText));
                expression = (Expression)parser.parseExp();
            } break;
            case EQUATION_BLOCK: {
                SCLParserImpl parser = new SCLParserImpl(new StringReader(expressionText));
                SCLParserOptions parserOptions = new SCLParserOptions();
                parserOptions.supportEq = true;
                parser.setParserOptions(parserOptions);
                expression = (Expression)parser.parseEquationBlock();
            } break;
            }
        } catch(SCLSyntaxErrorException e) {
            errorLog.log(e.location, e.getMessage());
            throw compilationFailure(errorLog);
        } catch(Exception e) {
            errorLog.log(e);
            throw compilationFailure(errorLog);
        }
    }

    /**
     * Prepares a block expression for evaluation; other expressions are left untouched.
     *
     * <p>If {@link #localStorage} is set and the block does not end in a guard statement,
     * a statement is appended for each variable bound by a top-level let statement that
     * stores the variable's value to the local storage. A block that still does not end in a
     * guard statement is then terminated by returning the unit tuple.
     *
     * @return the fresh types of the stored local variables (empty if nothing is stored)
     */
    private ArrayList<Type> completeBlock() {
        ArrayList<Type> lvTypes = new ArrayList<Type>();
        if(!(expression instanceof EBlock))
            return lvTypes;
        EBlock block = (EBlock)expression;
        if(localStorage != null && !(block.getLast() instanceof GuardStatement)) {
            for(String variableName : letBoundVariableNames(block)) {
                Type type = Types.metaVar(Kinds.STAR);
                lvTypes.add(type);
                block.addStatement(new GuardStatement(new EApply(
                        new EExternalConstant(
                                new StoreFunction(localStorage, variableName, type),
                                Types.functionE(type, Types.PROC, Types.UNIT)),
                        new EVar(variableName))));
                // In validate-only mode no code is run, so the store statement never executes;
                // register the variable with a null value instead.
                if(validateOnly)
                    localStorage.store(variableName, null, type);
            }
        }
        if(!(block.getLast() instanceof GuardStatement))
            block.addStatement(new GuardStatement(new EConstant(Builtins.TUPLE_CONSTRUCTORS[0])));
        return lvTypes;
    }

    /** Collects the head variable names of the top-level let statements of {@code block}. */
    private static THashSet<String> letBoundVariableNames(EBlock block) {
        THashSet<String> names = new THashSet<String>();
        for(Statement statement : block.getStatements()) {
            if(!(statement instanceof LetStatement))
                continue;
            try {
                names.add(((LetStatement)statement).pattern.getPatternHead().name);
            } catch (NotPatternException e) {
                // The left-hand side has no pattern head; there is nothing to store.
            }
        }
        return names;
    }

    /** Throws the collected errors if {@code errorLog} contains any. */
    private static void checkNoErrors(ErrorLog errorLog) throws SCLExpressionCompilationException {
        if(!errorLog.hasNoErrors())
            throw compilationFailure(errorLog);
    }

    /** Creates the exception that reports all errors collected so far. */
    private static SCLExpressionCompilationException compilationFailure(ErrorLog errorLog) {
        return new SCLExpressionCompilationException(errorLog.getErrors());
    }

    /** Returns the type of the expression computed by the last evaluation, or null before it. */
    public Type getType() {
        return expressionType;
    }

    /** Returns the source text, or null if the evaluator was created from an expression. */
    public String getExpressionText() {
        return expressionText;
    }
    
}
