package org.simantics.scl.compiler.internal.elaboration.constraints;

import java.util.ArrayList;

import org.simantics.scl.compiler.types.TApply;
import org.simantics.scl.compiler.types.TCon;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.util.MultiApply;

import gnu.trove.map.hash.THashMap;
import gnu.trove.procedure.TObjectObjectProcedure;

/**
 * Lookup tree that maps a list of types to a value of type {@code T} by
 * branching on the head type constructor of one type parameter at a time.
 * <p>
 * Each {@link SplitNode} inspects one position of the type list, decomposes
 * the type found there with {@code Types.matchApply} and descends into the
 * subtree registered for its constructor. On the way down the inspected type
 * is replaced by its own parameters, so that nested constructors can be
 * discriminated as well.
 * <p>
 * NOTE(review): nothing in this class assigns {@link #root} or calls
 * {@code create}; the tree is presumably meant to be populated by code that
 * is not written yet.
 */
public class InstanceTree<T> {

    /** Root of the tree; {@code null} means that the tree contains no entries. */
    Node<T> root;
    
    /** A node of the lookup tree. */
    private static interface Node<T> {
        /**
         * Looks up the value for the given types.
         * 
         * @param types the types in the layout expected by this node; implementations
         *        must not modify the list
         * @return the value found, or {@code null} if nothing matches
         */
        T get(ArrayList<Type> types);        
    }
    
    /**
     * Node that branches on the head constructor of the type at position {@code pId}.
     */
    private static class SplitNode<T> implements Node<T> {
        /** Index of the type this node discriminates on. */
        int pId;
        /** Subtrees keyed by the head constructor of the type at {@code pId}. */
        THashMap<TCon, Node<T>> map;
        /**
         * Subtree for the entries whose type at {@code pId} has no constructor head,
         * or {@code null} if there are no such entries. It uses the same type layout
         * as this node.
         */
        Node<T> alternative;
        
        public SplitNode(int pId, THashMap<TCon, Node<T>> map, Node<T> alternative) {
            this.pId = pId;
            this.map = map;
            this.alternative = alternative;
        }

        @Override
        public T get(ArrayList<Type> types) {
            MultiApply apply = Types.matchApply(types.get(pId));
            Node<T> node = map.get(apply.constructor);
            if(node != null) {
                // The subtree was built from entries where the type at pId had been
                // replaced in place by its parameters (see create), so the query must
                // be rewritten the same way. A copy is used in order to keep the
                // original list intact for the alternative branch and for the caller.
                int size = types.size();
                ArrayList<Type> childTypes = 
                        new ArrayList<Type>(size - 1 + apply.parameters.length);
                for(int i=0;i<pId;++i)
                    childTypes.add(types.get(i));
                for(Type parameter : apply.parameters)
                    childTypes.add(parameter);
                for(int i=pId+1;i<size;++i)
                    childTypes.add(types.get(i));
                T result = node.get(childTypes);
                if(result != null)
                    return result;
            }
            // Either no subtree exists for the constructor or the subtree did not
            // contain a match: try the entries that are not indexed by constructor.
            if(alternative == null)
                return null;
            return alternative.get(types);
        }
    }
    
    /**
     * Terminal node that is reached when no remaining type can be discriminated
     * by constructor.
     * <p>
     * NOTE(review): the node returns its value without checking the remaining
     * (non-constructor) types against the query. Confirm whether an additional
     * match, for example for repeated type variables, is required here.
     */
    private static class LeafNode<T> implements Node<T> {
        T value;
        
        public LeafNode(T value) {
            this.value = value;
        }
        
        @Override
        public T get(ArrayList<Type> types) {
            return value;
        }
    }
    
    /** A value together with the types under which it is stored in the tree. */
    private static class Entry<T> {
        /** Remaining types of the entry; rewritten by {@code create} while the tree is built. */
        Type[] types;
        T value;
    }
    
    /**
     * Tells whether the head of the given type is a type constructor, i.e. whether
     * following the {@code function} side of nested {@link TApply} nodes ends in a
     * {@link TCon}.
     */
    private static boolean isIndexable(Type type) {
        type = Types.canonical(type);
        while(true) {
            if(type instanceof TCon)
                return true;
            else if(type instanceof TApply)
                type = Types.canonical(((TApply)type).function);
            else
                return false;
        }
    }
    
    /**
     * Chooses the type position to branch on: the position where most entries have
     * an indexable type. Ties are resolved in favor of the lowest index.
     * 
     * @param entries a non-empty list of entries; the arity is taken from the first entry
     * @return the chosen position, or -1 if no position is indexable in any entry
     *         (this includes entries without types)
     */
    private static <T> int choosePId(ArrayList<Entry<T>> entries) {
        int arity = entries.get(0).types.length;
        int bestPId = -1;
        int bestIndexableCount = 0;
        for(int i=0;i<arity;++i) {
            int indexableCount = 0;
            for(Entry<T> entry : entries)
                if(isIndexable(entry.types[i]))
                    ++indexableCount;
            if(indexableCount > bestIndexableCount) {
                bestIndexableCount = indexableCount;
                bestPId = i;
            }
        }
        return bestPId;
    }
        
    /**
     * Builds a tree for the given entries. The {@code types} arrays of the entries
     * are rewritten during the construction.
     * 
     * @param entries the entries to be stored
     * @return the root of the created tree, or {@code null} if {@code entries} is empty
     */
    private static <T> Node<T> create(ArrayList<Entry<T>> entries) {
        if(entries.isEmpty())
            return null;
        int pId = choosePId(entries);
        if(pId < 0) {
            // Nothing left to branch on.
            // NOTE(review): if several entries end up here, only the first one is
            // reachable; confirm whether this should be reported as an ambiguity.
            return new LeafNode<T>(entries.get(0).value);
        }
        THashMap<TCon, ArrayList<Entry<T>>> map1 = new THashMap<TCon, ArrayList<Entry<T>>>();
        ArrayList<Entry<T>> otherEntries = new ArrayList<Entry<T>>(); 
        for(Entry<T> entry : entries) {
            Type[] types = entry.types;
            MultiApply apply = Types.matchApply(types[pId]);
            if(apply.constructor instanceof TCon) {
                TCon constructor = (TCon)apply.constructor;
                ArrayList<Entry<T>> l = map1.get(constructor);
                if(l == null) {
                    l = new ArrayList<Entry<T>>();
                    map1.put(constructor, l);
                }
                // Replace the type at pId by its parameters, keeping the other
                // types in their original order.
                int parameterCount = apply.parameters.length;
                Type[] newTypes = new Type[types.length-1+parameterCount];
                System.arraycopy(types, 0, newTypes, 0, pId);
                System.arraycopy(apply.parameters, 0, newTypes, pId, parameterCount);
                System.arraycopy(types, pId+1, newTypes, pId+parameterCount, types.length-pId-1);
                entry.types = newTypes;
                l.add(entry);
            }
            else {
                otherEntries.add(entry);
            }
        }
        if(otherEntries.size() == entries.size()) {
            // No entry could be split although choosePId considered the position
            // indexable. Stop here, because recursing with the same entries would
            // never terminate.
            return new LeafNode<T>(entries.get(0).value);
        }
        final THashMap<TCon, Node<T>> map = new THashMap<TCon, Node<T>>();
        map1.forEachEntry(new TObjectObjectProcedure<TCon, ArrayList<Entry<T>>>() {
            @Override
            public boolean execute(TCon a, ArrayList<Entry<T>> b) {
                map.put(a, create(b));
                return true;
            }
        });
        return new SplitNode<T>(pId, map, create(otherEntries));
    }
    
    /**
     * Looks up the value stored for the given types.
     * 
     * @param types the types to look up; the list is not modified
     * @return the value found, or {@code null} if the tree is empty or nothing matches
     */
    public T get(ArrayList<Type> types) {
        if(root == null)
            return null;
        return root.get(types);
    }
}
