package org.simantics.scl.compiler.types;

import org.simantics.scl.compiler.environment.Environment;
import org.simantics.scl.compiler.internal.types.HashCodeUtils;
import org.simantics.scl.compiler.types.exceptions.KindUnificationException;
import org.simantics.scl.compiler.types.exceptions.UnificationException;
import org.simantics.scl.compiler.types.kinds.Kind;
import org.simantics.scl.compiler.types.kinds.Kinds;

import gnu.trove.map.hash.THashMap;

public class Skeletons {
    
    /**
     * Resolves a type to the representative of its skeleton.
     * <p>
     * Meta-variables are followed through {@code ref} first and through
     * {@code skeletonRef} otherwise. When a {@code skeletonRef} link is
     * followed, the link is overwritten with the representative that was
     * found, which shortens the chain for later lookups.
     *
     * @param type the type to resolve
     * @return the first type on the chain that is not a meta-variable, or the
     *         meta-variable on the chain that has neither {@code ref} nor
     *         {@code skeletonRef} set
     */
    public static Type canonicalSkeleton(Type type) {
        Type current = type;
        while(current instanceof TMetaVar) {
            TMetaVar meta = (TMetaVar)current;
            if(meta.ref != null) {
                current = meta.ref;
                continue;
            }
            if(meta.skeletonRef == null)
                return meta;
            // Store the resolved representative back into the link before returning it.
            Type representative = canonicalSkeleton(meta.skeletonRef);
            meta.skeletonRef = representative;
            return representative;
        }
        return current;
    }
    
    /**
     * Resolves a type to the representative of its skeleton, additionally
     * following the tentative bindings stored in {@code unifications}.
     * <p>
     * For each meta-variable on the chain the lookup order is {@code ref},
     * then {@code skeletonRef}, then the map. Unlike
     * {@link #canonicalSkeleton(Type)}, this variant does not assign any field
     * of the meta-variables it visits.
     *
     * @param unifications tentative bindings of otherwise unbound meta-variables
     * @param type the type to resolve
     * @return the first type on the chain that is not a meta-variable, or the
     *         meta-variable that is unbound and has no entry in the map
     */
    public static Type canonicalSkeleton(THashMap<TMetaVar,Type> unifications, Type type) {
        Type current = type;
        for(;;) {
            if(!(current instanceof TMetaVar))
                return current;
            TMetaVar meta = (TMetaVar)current;
            Type next = meta.ref;
            if(next == null)
                next = meta.skeletonRef;
            if(next == null)
                next = unifications.get(meta);
            if(next == null)
                return meta;
            current = next;
        }
    }
    
    /**
     * Checks whether {@code metaVar} occurs in the skeleton of {@code type}.
     * <p>
     * The type is resolved with {@link #canonicalSkeleton(THashMap, Type)}
     * before it is inspected. The search descends into the domain and range
     * of functions, both sides of type applications, the body of universal
     * quantifications and the parameters of predicates. Function effects are
     * not visited.
     *
     * @param unifications tentative bindings used while resolving meta-variables
     * @param type the type to search
     * @param metaVar the meta-variable to look for (compared by identity)
     * @return true, if the meta-variable was found
     */
    public static boolean doesSkeletonContain(THashMap<TMetaVar,Type> unifications, Type type, TMetaVar metaVar) {
        Type resolved = canonicalSkeleton(unifications, type);
        if(resolved == metaVar)
            return true;
        if(resolved instanceof TFun) {
            TFun fun = (TFun)resolved;
            if(doesSkeletonContain(unifications, fun.domain, metaVar))
                return true;
            return doesSkeletonContain(unifications, fun.range, metaVar);
        }
        if(resolved instanceof TApply) {
            TApply apply = (TApply)resolved;
            if(doesSkeletonContain(unifications, apply.function, metaVar))
                return true;
            return doesSkeletonContain(unifications, apply.parameter, metaVar);
        }
        if(resolved instanceof TForAll)
            return doesSkeletonContain(unifications, ((TForAll)resolved).type, metaVar);
        if(resolved instanceof TPred) {
            Type[] parameters = ((TPred)resolved).parameters;
            for(int i=0;i<parameters.length;++i)
                if(doesSkeletonContain(unifications, parameters[i], metaVar))
                    return true;
        }
        return false;
    }

    /**
     * Returns true, if unification of the skeletons of the types would succeed.
     * <p>
     * Instead of binding meta-variables, the method records tentative bindings
     * in {@code unifications}. Bindings added before a mismatch is found are
     * not removed when the method returns false.
     * <p>
     * Function effects are not compared. Predicates are compatible only if
     * they have the same type class, the same number of parameters and
     * pairwise compatible parameters.
     *
     * @param unifications tentative bindings; read and extended by this method
     * @param a the first type
     * @param b the second type
     * @return true, if the skeletons were found to be compatible
     */
    public static boolean areSkeletonsCompatible(THashMap<TMetaVar,Type> unifications, Type a, Type b) {
        a = canonicalSkeleton(unifications, a);
        b = canonicalSkeleton(unifications, b);
        if(a == b)
            return true;
        Class<?> ca = a.getClass();
        Class<?> cb = b.getClass();
        
        if(ca == TMetaVar.class) {
            TMetaVar ma = (TMetaVar)a;
            // Occurs check: refuse a binding that would make the skeleton cyclic.
            if(doesSkeletonContain(unifications, b, ma))
                return false;
            unifications.put(ma, b);
            return true;
        }
        if(cb == TMetaVar.class) {
            TMetaVar mb = (TMetaVar)b;
            if(doesSkeletonContain(unifications, a, mb))
                return false;
            unifications.put(mb, a);
            return true;
        }
        if(ca != cb)
            return false;
        if(ca == TFun.class) {
            TFun funA = (TFun)a;
            TFun funB = (TFun)b;
            return areSkeletonsCompatible(unifications, funA.domain, funB.domain)
                    && areSkeletonsCompatible(unifications, funA.range, funB.range);
        }
        if(ca == TApply.class) {
            TApply applyA = (TApply)a;
            TApply applyB = (TApply)b;
            return areSkeletonsCompatible(unifications, applyA.function, applyB.function)
                    && areSkeletonsCompatible(unifications, applyA.parameter, applyB.parameter);
        }
        if(ca == TPred.class) {
            TPred predA = (TPred)a;
            TPred predB = (TPred)b;
            // The length check mirrors unifySkeletons(TPred, TPred) and
            // equalSkeletons(TPred, TPred); without it the loop below could
            // index past the end of predB.parameters or ignore extra parameters.
            if(predA.typeClass != predB.typeClass
                    || predA.parameters.length != predB.parameters.length)
                return false;
            for(int i=0;i<predA.parameters.length;++i)
                if(!areSkeletonsCompatible(unifications, predA.parameters[i], predB.parameters[i]))
                    return false;
            return true;
        }
        if(ca == TForAll.class) {
            TForAll forAllA = (TForAll)a;
            TForAll forAllB = (TForAll)b;
            // Compare the bodies under one shared fresh variable. The kinds of
            // the two bound variables are not compared here.
            TVar temp = Types.var(forAllA.var.getKind());
            return areSkeletonsCompatible(unifications,
                    forAllA.type.replace(forAllA.var, temp),
                    forAllB.type.replace(forAllB.var, temp));
        }
        // Any other class: the types are different objects (a != b), which is
        // treated as incompatible.
        return false;
    }

    /**
     * Unifies the skeletons of two types.
     * <p>
     * Both types are first resolved with {@link #canonicalSkeleton(Type)}. If
     * either resolved type is a meta-variable, its skeleton reference is set
     * to the other type. Otherwise the resolved types must have exactly the
     * same class: functions, universal quantifications and unions are handled
     * by the corresponding overloads of this class, while type applications
     * and predicates are passed to {@code Types.unify}.
     *
     * @param a the first type
     * @param b the second type
     * @throws UnificationException if the resolved types are distinct and
     *         have different classes or a class with no rule here, or if a
     *         delegated step fails
     */
    public static void unifySkeletons(Type a, Type b) throws UnificationException {
        Type left = canonicalSkeleton(a);
        Type right = canonicalSkeleton(b);
        if(left == right)
            return;

        if(left instanceof TMetaVar) {
            ((TMetaVar)left).setSkeletonRef(right);
            return;
        }
        if(right instanceof TMetaVar) {
            ((TMetaVar)right).setSkeletonRef(left);
            return;
        }

        Class<?> shape = left.getClass();
        if(shape != right.getClass())
            throw new UnificationException(left, right);

        if(shape == TApply.class || shape == TPred.class) {
            // Full unification is used here instead of the
            // unifySkeletons(TApply, TApply) / unifySkeletons(TPred, TPred) overloads.
            Types.unify(left, right);
            return;
        }
        if(shape == TFun.class) {
            unifySkeletons((TFun)left, (TFun)right);
            return;
        }
        if(shape == TForAll.class) {
            unifySkeletons((TForAll)left, (TForAll)right);
            return;
        }
        if(shape == TUnion.class) {
            unifySkeletons((TUnion)left, (TUnion)right);
            return;
        }
        // shape == TCon.class || shape == TVar.class: two distinct objects never unify
        throw new UnificationException(left, right);
    }
    
    /**
     * Unifies the skeletons of two function types: first the domains, then
     * the ranges. The effects are not touched.
     *
     * @throws UnificationException if either pair fails to unify
     */
    public static void unifySkeletons(TFun a, TFun b) throws UnificationException {
        Type domainA = a.domain;
        Type domainB = b.domain;
        unifySkeletons(domainA, domainB);

        Type rangeA = a.range;
        Type rangeB = b.range;
        unifySkeletons(rangeA, rangeB);
    }

    /**
     * Unifies the skeletons of two type applications: first the applied
     * functions, then the parameters.
     *
     * @throws UnificationException if either pair fails to unify
     */
    public static void unifySkeletons(TApply a, TApply b) throws UnificationException {
        Type functionA = a.function;
        Type functionB = b.function;
        unifySkeletons(functionA, functionB);

        Type parameterA = a.parameter;
        Type parameterB = b.parameter;
        unifySkeletons(parameterA, parameterB);
    }

    /**
     * Unifies the skeletons of two universally quantified types.
     * <p>
     * The kinds of the bound variables are unified first. Both bodies are
     * then rewritten to use one shared fresh variable in place of their own
     * bound variable, and the rewritten bodies are unified.
     *
     * @throws UnificationException if the kinds cannot be unified or the
     *         bodies fail to unify
     */
    public static void unifySkeletons(TForAll a, TForAll b) throws UnificationException {
        try {
            Kinds.unify(a.var.getKind(), b.var.getKind());
        }
        catch(KindUnificationException e) {
            throw new UnificationException(a, b);
        }

        TVar shared = Types.var(a.var.getKind());
        Type bodyA = a.type.replace(a.var, shared);
        Type bodyB = b.type.replace(b.var, shared);
        unifySkeletons(bodyA, bodyB);
    }

    /**
     * Unifies the skeletons of two predicates. The predicates must have the
     * same type class and the same number of parameters; the parameters are
     * then unified pairwise in order.
     *
     * @throws UnificationException if the type classes or parameter counts
     *         differ, or if a parameter pair fails to unify
     */
    public static void unifySkeletons(TPred a, TPred b) throws UnificationException {
        Type[] left = a.parameters;
        Type[] right = b.parameters;
        boolean sameShape = a.typeClass == b.typeClass
                && left.length == right.length;
        if(!sameShape)
            throw new UnificationException(a, b);
        for(int i=0;i<left.length;++i)
            unifySkeletons(left[i], right[i]);
    }

    /**
     * Skeleton unification of two unions. The method performs no work and
     * never throws: any two unions are accepted as they are.
     *
     * @param a the first union (not inspected)
     * @param b the second union (not inspected)
     */
    public static void unifySkeletons(TUnion a, TUnion b) throws UnificationException {
        // Nothing to do
    }
    
    /**
     * Computes a common skeleton of the given types, starting from an empty
     * table of meta-variables.
     * <p>
     * The table is keyed by arrays of types. Keys are compared with
     * {@code Types.equals(Type[], Type[])} and hashed by folding
     * {@code Type.hashCode(int)} over the elements, starting from
     * {@code HashCodeUtils.SEED}. Positions whose type arrays compare equal
     * therefore share a single meta-variable in the result.
     * <p>
     * Note: the elements of {@code types} are replaced in place by their
     * canonical skeletons, and the first element is accessed unconditionally,
     * so the array must not be empty.
     *
     * @param context environment used for kind inference of new meta-variables
     * @param types the types to generalize
     * @return the common skeleton
     */
    public static Type commonSkeleton(Environment context, Type[] types) {
        THashMap<Type[], TMetaVar> metaVarMap = new THashMap<Type[], TMetaVar>() {
            @Override
            protected int hash(Object key) {
                Type[] elements = (Type[])key;
                int result = HashCodeUtils.SEED;
                for(int i=0;i<elements.length;++i)
                    result = elements[i].hashCode(result);
                return result;
            }
            @Override
            protected boolean equals(Object left, Object right) {
                return Types.equals((Type[])left, (Type[])right);
            }
        };
        return commonSkeleton(context, metaVarMap, types);
    }

    /**
     * Returns the meta-variable registered for the given array of types,
     * creating and registering one on first use.
     * <p>
     * A new meta-variable gets the kind inferred for the first type of the
     * array; if kind inference fails, {@code Kinds.STAR} is used instead.
     *
     * @param context environment used for kind inference
     * @param metaVarMap table of meta-variables created so far
     * @param types the key; also the source of the kind
     * @return the cached or newly created meta-variable
     */
    private static TMetaVar metaVarFor(Environment context, THashMap<Type[], TMetaVar> metaVarMap, Type[] types) {
        TMetaVar cached = metaVarMap.get(types);
        if(cached != null)
            return cached;

        TMetaVar fresh;
        try {
            fresh = Types.metaVar(types[0].inferKind(context));
        }
        catch(KindUnificationException e) {
            // Fall back to kind * when the kind cannot be inferred.
            fresh = Types.metaVar(Kinds.STAR);
        }
        metaVarMap.put(types, fresh);
        return fresh;
    }
    
    /**
     * Finds the most specific type that can be unified with all the types
     * given as a parameter.
     * <p>
     * The elements of {@code types} are first replaced in place by their
     * canonical skeletons. If the resolved types do not all have the same
     * class, a meta-variable is returned. Otherwise:
     * <ul>
     * <li>type constructors yield the constructor itself when all elements
     *     are the same object, and a meta-variable otherwise;</li>
     * <li>type applications are generalized component-wise (function, then
     *     parameter);</li>
     * <li>functions are generalized component-wise (domain, effect, range),
     *     unless some domain is a predicate, in which case a meta-variable is
     *     returned;</li>
     * <li>every other class yields a meta-variable.</li>
     * </ul>
     * Meta-variables come from {@code metaVarFor}, so equal arrays of types
     * map to the same meta-variable.
     *
     * @param context environment used for kind inference of new meta-variables
     * @param metaVarMap table of meta-variables created so far
     * @param types the types to generalize; modified in place, must not be empty
     * @return the common skeleton
     */
    private static Type commonSkeleton(Environment context, THashMap<Type[], TMetaVar> metaVarMap, Type[] types) {
        final int count = types.length;
        for(int i=0;i<count;++i)
            types[i] = canonicalSkeleton(types[i]);

        Type head = types[0];
        Class<?> shape = head.getClass();
        for(int i=1;i<count;++i) {
            if(types[i].getClass() != shape)
                return metaVarFor(context, metaVarMap, types);
        }

        if(shape == TCon.class) {
            for(int i=1;i<count;++i) {
                if(types[i] != head)
                    return metaVarFor(context, metaVarMap, types);
            }
            return head;
        }

        if(shape == TApply.class) {
            Type[] functions = new Type[count];
            Type[] parameters = new Type[count];
            for(int i=0;i<count;++i) {
                TApply apply = (TApply)types[i];
                functions[i] = apply.function;
                parameters[i] = apply.parameter;
            }
            // Function side first, then parameter side.
            Type commonFunction = commonSkeleton(context, metaVarMap, functions);
            Type commonParameter = commonSkeleton(context, metaVarMap, parameters);
            return Types.apply(commonFunction, commonParameter);
        }

        if(shape == TFun.class) {
            Type[] domains = new Type[count];
            Type[] effects = new Type[count];
            Type[] ranges = new Type[count];
            for(int i=0;i<count;++i) {
                TFun fun = (TFun)types[i];
                if(fun.domain instanceof TPred)
                    return metaVarFor(context, metaVarMap, types);
                domains[i] = fun.domain;
                effects[i] = fun.effect;
                ranges[i] = fun.range;
            }
            // Domain, effect, range: evaluated in this order.
            Type commonDomain = commonSkeleton(context, metaVarMap, domains);
            Type effect = commonEffect(effects);
            Type commonRange = commonSkeleton(context, metaVarMap, ranges);
            return Types.functionE(commonDomain, effect, commonRange);
        }

        return metaVarFor(context, metaVarMap, types);
    }

    /**
     * Returns the first effect if every other effect equals it according to
     * {@code Types.equals}; otherwise returns a new meta-variable of kind
     * {@code Kinds.EFFECT}.
     *
     * @param effects the effects to compare; the first element is accessed
     *        unconditionally
     * @return the shared effect or a fresh effect meta-variable
     */
    private static Type commonEffect(Type[] effects) {
        Type shared = effects[0];
        int index = 1;
        while(index < effects.length) {
            if(!Types.equals(shared, effects[index]))
                return Types.metaVar(Kinds.EFFECT);
            ++index;
        }
        return shared;
    }
    
    /**
     * Tests whether two type applications have equal skeletons. The
     * parameters are compared first, then the applied functions.
     */
    public static boolean equalSkeletons(TApply a, TApply b) {
        if(!equalSkeletons(a.parameter, b.parameter))
            return false;
        return equalSkeletons(a.function, b.function);
    }

    /**
     * Tests whether two function types have equal skeletons. Only the domains
     * and the ranges are compared; the effects are not.
     */
    public static boolean equalSkeletons(TFun a, TFun b) {
        if(!equalSkeletons(a.domain, b.domain))
            return false;
        return equalSkeletons(a.range, b.range);
    }
    
    /**
     * Tests whether two universally quantified types have equal skeletons.
     * The kinds of the bound variables must be equal according to
     * {@code Kinds.equalsCanonical}; the bodies are then compared after both
     * bound variables have been replaced by one shared fresh variable.
     */
    public static boolean equalSkeletons(TForAll a, TForAll b) {
        Kind kind = a.var.getKind();
        if(Kinds.equalsCanonical(kind, b.var.getKind())) {
            TVar shared = Types.var(kind);
            Type bodyA = a.type.replace(a.var, shared);
            Type bodyB = b.type.replace(b.var, shared);
            return equalSkeletons(bodyA, bodyB);
        }
        return false;
    }

    /**
     * Tests whether two predicates have equal skeletons: same type class,
     * same number of parameters, and pairwise equal parameter skeletons in
     * the same order.
     */
    public static boolean equalSkeletons(TPred a, TPred b) {
        if(a.typeClass != b.typeClass)
            return false;
        Type[] left = a.parameters;
        Type[] right = b.parameters;
        int count = left.length;
        if(count != right.length)
            return false;
        for(int i=0;i<count;++i) {
            if(!equalSkeletons(left[i], right[i]))
                return false;
        }
        return true;
    }

    /**
     * Tests equality of the skeletons of two types.
     * <p>
     * Both types are first resolved with {@link #canonicalSkeleton(Type)}, so
     * a meta-variable with {@code ref} or {@code skeletonRef} set is compared
     * through the type it leads to. Resolved types of different classes are
     * never equal. Type applications, functions, universal quantifications
     * and predicates are compared structurally by the corresponding
     * overloads; predicate parameters are compared in order. Types of any
     * other class (type constructors, type variables, unbound meta-variables)
     * are equal only if they are the same object.
     *
     * @param a the first type
     * @param b the second type
     * @return true, if the skeletons are equal
     */
    public static boolean equalSkeletons(Type a, Type b) {
        Type left = canonicalSkeleton(a);
        Type right = canonicalSkeleton(b);
        if(left == right)
            return true;

        Class<?> shape = left.getClass();
        if(shape != right.getClass())
            return false;

        if(shape == TApply.class)
            return equalSkeletons((TApply)left, (TApply)right);
        if(shape == TFun.class)
            return equalSkeletons((TFun)left, (TFun)right);
        if(shape == TForAll.class)
            return equalSkeletons((TForAll)left, (TForAll)right);
        if(shape == TPred.class)
            return equalSkeletons((TPred)left, (TPred)right);

        // TCon, TVar or unbound TMetaVar: equal only when left == right,
        // which was already tested above.
        return false;
    }
}
