package org.simantics.scl.compiler.internal.codegen.utils;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.constants.Constant;
import org.simantics.scl.compiler.constants.SCLConstant;
import org.simantics.scl.compiler.internal.codegen.continuations.Cont;
import org.simantics.scl.compiler.internal.codegen.continuations.ContRef;
import org.simantics.scl.compiler.internal.codegen.references.BoundVar;
import org.simantics.scl.compiler.internal.codegen.references.Val;
import org.simantics.scl.compiler.internal.codegen.references.ValRef;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.util.TypeUnparsingContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.procedure.TObjectIntProcedure;
import gnu.trove.set.hash.THashSet;

/**
 * Mutable state used while validating SSA code: the bound variables, continuations and
 * type variables currently considered valid, plus per-value reference-count bookkeeping.
 *
 * <p>Validation failures are reported by throwing {@link InternalCompilerError}. The type
 * assertions additionally log the mismatch and record their {@code loc} argument in
 * {@link #errorMarker} before throwing.
 */
public class SSAValidationContext {

    private static final Logger LOGGER = LoggerFactory.getLogger(SSAValidationContext.class);

    /** Bound variables that a {@link ValRef} may legally refer to. */
    public THashSet<BoundVar> validBoundVariables = new THashSet<BoundVar>();
    /** Continuations that a {@link ContRef} may legally refer to. */
    public THashSet<Cont> validContinuations = new THashSet<Cont>();
    /**
     * Type variables considered valid. Within this class it is only cleared by
     * {@link #reset()}; the check in {@link #validateType(Type)} that would read it is disabled.
     */
    public THashSet<TVar> validTypeVariables = new THashSet<TVar>();
    /** Number of {@link ValRef}s validated per bound value; compared in {@link #checkReferences()}. */
    public TObjectIntHashMap<Val> refCount = new TObjectIntHashMap<Val>();
    /** The {@code loc} object of the most recent failing type assertion, or {@code null}. */
    public Object errorMarker = null;

    /**
     * Asserts that two types are equal according to {@link Types#equals}.
     *
     * @param loc location recorded in {@link #errorMarker} on failure
     * @throws InternalCompilerError if the types differ
     */
    public void assertEquals(Object loc, Type a, Type b) {
        if(!Types.equals(a, b))
            throw typeAssertionFailure(loc, a, " != ", b);
    }

    /**
     * Asserts that {@code a} subsumes {@code b} according to {@link Types#subsumes}.
     *
     * @param loc location recorded in {@link #errorMarker} on failure
     * @throws InternalCompilerError if the subsumption does not hold
     */
    public void assertSubsumes(Object loc, Type a, Type b) {
        if(!Types.subsumes(a, b))
            throw typeAssertionFailure(loc, a, " <! ", b);
    }

    /**
     * Asserts that two effect types are equal according to {@link Types#equalsEffect}.
     *
     * @param loc location recorded in {@link #errorMarker} on failure
     * @throws InternalCompilerError if the effects differ
     */
    public void assertEqualsEffect(Object loc, Type a, Type b) {
        if(!Types.equalsEffect(a, b))
            throw typeAssertionFailure(loc, a, " != ", b);
    }

    /**
     * Builds, logs and returns the error for a failed type assertion, recording {@code loc}
     * as the error marker. Both types are printed with one shared
     * {@link TypeUnparsingContext} ({@code a} first) so that type variable names agree.
     */
    private InternalCompilerError typeAssertionFailure(Object loc, Type a, String relation, Type b) {
        TypeUnparsingContext tuc = new TypeUnparsingContext();
        String message = a.toString(tuc) + relation + b.toString(tuc);
        LOGGER.error(message);
        setErrorMarker(loc);
        return new InternalCompilerError(message);
    }

    /**
     * Asserts that two integers are equal.
     *
     * @throws InternalCompilerError if {@code a != b}
     */
    public void assertEquals(int a, int b) {
        if(a != b)
            throw new InternalCompilerError(a + " != " + b);
    }

    /**
     * Clears {@link #validContinuations} and {@link #validTypeVariables}.
     * NOTE(review): {@link #validBoundVariables} and {@link #refCount} are deliberately(?)
     * left untouched — confirm with callers whether that is intended.
     */
    public void reset() {
        validContinuations.clear();
        validTypeVariables.clear();
    }

    /** Validates the parameter types of a continuation (see {@link #validateType(Type)}). */
    public void validate(Cont cont) {
        for(int i=0;i<cont.getArity();++i)
            validateType(cont.getParameterType(i));
    }

    /** Validates the type of a value (see {@link #validateType(Type)}). */
    public void validate(Val val) {
        validateType(val.getType());
    }

    /** Returns true if {@code occ} is in the linked occurrence list of {@code cont}. */
    private static boolean hasOccurrence(Cont cont, ContRef occ) {
        for(ContRef ref = cont.getOccurrence();
                ref != null;
                ref = ref.getNext())
            if(ref == occ)
                return true;
        return false;
    }

    /**
     * Checks that a continuation reference points to a valid continuation, is linked into
     * that continuation's occurrence list and has a parent.
     *
     * @throws InternalCompilerError if any of the checks fails
     */
    public void validate(ContRef ref) {
        Cont binding = ref.getBinding();
        if(!validContinuations.contains(binding))
            throw new InternalCompilerError("Reference to an invalid continuation.");
        if(!hasOccurrence(binding, ref))
            throw new InternalCompilerError("Continuation reference is not in the occurrence list of its binding.");
        if(ref.getParent() == null)
            throw new InternalCompilerError("Continuation reference has no parent.");
    }

    /** Set by {@link #checkReferences()} when a mismatch is found (written from the callback). */
    boolean invalidReferenceCounts;

    /**
     * Compares the counts collected in {@link #refCount} with each value's
     * {@code occurrenceCount()}, logging every mismatch.
     *
     * @throws InternalCompilerError if at least one count differs
     */
    public void checkReferences() {
        invalidReferenceCounts = false;
        refCount.forEachEntry(new TObjectIntProcedure<Val>() {
            @Override
            public boolean execute(Val val, int count) {
                if(isExemptFromReferenceCheck(val))
                    return true;
                int realCount = val.occurrenceCount();
                if(realCount != count) {
                    LOGGER.warn("{}: {} != {}", val, realCount, count);
                    invalidReferenceCounts = true;
                }
                return true; // keep iterating so every mismatch gets logged
            }
        });
        if(invalidReferenceCounts)
            throw new InternalCompilerError("Reference counts do not match occurrence counts (see log).");
    }

    /**
     * Returns true for values whose reference counts are not compared: every
     * {@link Constant} except {@link SCLConstant}s from module "Composition".
     * NOTE(review): why "Composition" constants are singled out is not visible here.
     */
    private static boolean isExemptFromReferenceCheck(Val val) {
        if(!(val instanceof Constant))
            return false;
        if(!(val instanceof SCLConstant))
            return true;
        return !((SCLConstant)val).getName().module.equals("Composition");
    }

    /**
     * Validates a value reference and counts it in {@link #refCount}. Constants are counted
     * but not checked further; other values must be valid bound variables and the
     * reference must have a parent.
     *
     * @throws InternalCompilerError if the reference is invalid
     */
    public void validate(ValRef ref) {
        Val val = ref.getBinding();
        // Check for null before counting so that checkReferences never sees a null key.
        if(val == null)
            throw new InternalCompilerError("Value reference has no binding.");
        refCount.adjustOrPutValue(val, 1, 1);
        if(val instanceof Constant)
            return;
        if(!validBoundVariables.contains(val))
            throw new InternalCompilerError("Reference to an invalid bound variable: " + val);

        if(ref.getParent() == null)
            throw new InternalCompilerError("Value reference has no parent.");
    }

    /**
     * Intended to check that every free type variable of {@code type} is in
     * {@link #validTypeVariables}. Currently a no-op: code involving existential data types
     * does not pass this test.
     */
    public void validateType(Type type) {
        /*for(TVar var : Types.freeVars(type))
            if(!validTypeVariables.contains(var))
                throw new InternalCompilerError();*/
    }

    /** Records the location of the most recent validation failure. */
    public void setErrorMarker(Object errorMarker) {
        this.errorMarker = errorMarker;
    }

}
