package org.simantics.graph.matching;

import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.TIntHashSet;

import java.util.Arrays;
import java.util.Comparator;

import org.simantics.databoard.binding.mutable.Variant;

public enum CanonicalizingMatchingStrategy implements GraphMatchingStrategy {
	INSTANCE;

	/**
	 * A resource that takes part in the canonicalization, together with
	 * its statements expressed through the current map values.
	 */
	private static class Vertex {
		/** Graph tag given at construction; {@code applyTo} passes 0 for graph A and 1 for graph B. */
		final int graph;
		/** Index of the resource in its own graph (used to index the statement and value arrays). */
		final int original;
		/** Current position of the vertex; vertices with the same position form one group. */
		int pos;
		/** Statements of the resource translated through the map; the elements are rewritten in place. */
		final Stat[] stats;
		
		/**
		 * @param graph    graph tag of the vertex
		 * @param original index of the resource in its graph
		 * @param pos      initial position
		 * @param stats    translated statements of the resource
		 */
		public Vertex(int graph, int original, int pos, Stat[] stats) {
			this.graph = graph;
			this.original = original;
			this.pos = pos;
			this.stats = stats;
		}
	}
	
	/**
	 * Orders vertices by position, then by number of statements, then by
	 * the statements themselves (element by element using
	 * {@link Stat#STAT_COMPARATOR}), then by graph tag and finally by the
	 * original resource index.
	 */
	private static final Comparator<Vertex> VERTEX_COMPARATOR = new Comparator<Vertex>() {
		/** Three-way comparison of two ints without subtraction. */
		private int order(int x, int y) {
			if(x < y)
				return -1;
			return x > y ? 1 : 0;
		}

		@Override
		public int compare(Vertex o1, Vertex o2) {
			int result = order(o1.pos, o2.pos);
			if(result != 0)
				return result;

			Stat[] left = o1.stats;
			Stat[] right = o2.stats;
			result = order(left.length, right.length);
			if(result != 0)
				return result;
			for(int k=0;k<left.length;++k) {
				result = Stat.STAT_COMPARATOR.compare(left[k], right[k]);
				if(result != 0)
					return result;
			}

			// Tie breakers: graph tag first, then the original index.
			result = order(o1.graph, o2.graph);
			if(result != 0)
				return result;
			return order(o1.original, o2.original);
		}
	};
	
	/**
	 * Builds the initial map for graph A. An entry {@code aToB[i] >= 0} is
	 * encoded as the negative number {@code -1 - aToB[i]}; every other
	 * entry becomes 0, which {@code generateVertices} treats as
	 * "needs a vertex".
	 *
	 * @param aToB array indexed by resources of A
	 * @return a new array of the same length as {@code aToB}
	 */
	private static int[] generateMapA(int[] aToB) {
		int length = aToB.length;
		int[] encoded = new int[length];
		for(int index=0;index<length;++index) {
			int target = aToB[index];
			encoded[index] = target >= 0 ? -1 - target : 0;
		}
		return encoded;
	}
	
	/**
	 * Builds the initial map for graph B. An index {@code i} with
	 * {@code bToA[i] >= 0} is encoded as {@code -1 - i} (its own index, not
	 * the value stored in {@code bToA}); every other entry becomes 0,
	 * which {@code generateVertices} treats as "needs a vertex".
	 *
	 * @param bToA array indexed by resources of B
	 * @return a new array of the same length as {@code bToA}
	 */
	private static int[] generateMapB(int[] bToA) {
		int length = bToA.length;
		int[] encoded = new int[length];
		for(int index=0;index<length;++index) {
			boolean matched = bToA[index] >= 0;
			encoded[index] = matched ? -1 - index : 0;
		}
		return encoded;
	}
	
	/**
	 * Creates a vertex (with position 0) for every resource whose map entry
	 * is 0. The statements of the resource are translated through
	 * {@code map} and sorted with {@link Stat#STAT_COMPARATOR}.
	 *
	 * @param graph      graph tag stored in every created vertex
	 * @param map        current map of the graph
	 * @param statements statements of the graph, indexed by resource
	 * @return the created vertices in increasing resource order
	 */
	private static Vertex[] generateVertices(int graph, int[] map, Stat[][] statements) {
		// First pass: count the entries that need a vertex.
		int pending = 0;
		for(int entry : map) {
			if(entry == 0)
				++pending;
		}

		// Second pass: build the vertices.
		Vertex[] vertices = new Vertex[pending];
		int next = 0;
		for(int resource=0;resource<map.length;++resource) {
			if(map[resource] != 0)
				continue;
			Stat[] originals = statements[resource];
			Stat[] translated = new Stat[originals.length];
			for(int k=0;k<originals.length;++k) {
				Stat source = originals[k];
				translated[k] = new Stat(map[source.p], map[source.o]);
			}
			Arrays.sort(translated, Stat.STAT_COMPARATOR);
			vertices[next++] = new Vertex(graph, resource, 0, translated);
		}
		return vertices;
	}
	
	/**
	 * Recomputes the translated statements of every vertex from the
	 * current {@code map} and sorts them again.
	 *
	 * @param vertices   vertices to refresh
	 * @param map        current map of the graph the vertices belong to
	 * @param statements statements of that graph, indexed by resource
	 */
	private static void updateVertices(Vertex[] vertices, int[] map, Stat[][] statements) {
		for(Vertex vertex : vertices) {
			Stat[] originals = statements[vertex.original];
			Stat[] translated = vertex.stats;
			// The Stat objects are reused as plain storage: each slot is
			// overwritten, so it does not matter which slot receives which
			// statement before the array is sorted again.
			int slot = 0;
			for(Stat source : originals) {
				Stat target = translated[slot++];
				target.p = map[source.p];
				target.o = map[source.o];
			}
			Arrays.sort(translated, Stat.STAT_COMPARATOR);
		}
	}
	
	/**
	 * Returns a new array containing the elements of {@code as} followed
	 * by the elements of {@code bs}.
	 */
	private static Vertex[] concat(Vertex[] as, Vertex[] bs) {
		int total = as.length + bs.length;
		// copyOf copies the first part and pads the rest with null
		Vertex[] joined = Arrays.copyOf(as, total);
		System.arraycopy(bs, 0, joined, as.length, bs.length);
		return joined;
	}
	
	/**
	 * Tells whether two statement arrays have the same length and
	 * pairwise identical {@code p} and {@code o} values.
	 */
	static boolean equals(Stat[] stats1, Stat[] stats2) {
		int length = stats1.length;
		if(length != stats2.length)
			return false;
		for(int k=0;k<length;++k) {
			Stat left = stats1[k];
			Stat right = stats2[k];
			boolean same = left.p == right.p && left.o == right.o;
			if(!same)
				return false;
		}
		return true;
	}
	
	/**
	 * Splits groups of vertices whose statements differ. The array is
	 * scanned once; consecutive vertices that shared a position on entry
	 * are compared by their statements. A vertex equal to its predecessor
	 * takes over the predecessor's (possibly new) position, otherwise it
	 * starts a new group whose position is its own index.
	 *
	 * @param can vertices to scan; {@code applyTo} sorts them with
	 *            {@code VERTEX_COMPARATOR} before each call
	 * @return {@code true} if at least one group was split,
	 *         {@code false} otherwise (also for an empty array)
	 */
	private static boolean updatePositions(Vertex[] can) {
		// Nothing to split; also avoids indexing can[0] on an empty array.
		if(can.length == 0)
			return false;
		boolean modified = false;
		// oldPos is the position the current run of vertices had on entry,
		// oldVertex is the directly preceding vertex.
		int oldPos = can[0].pos;
		Vertex oldVertex = can[0];
		for(int i=1;i<can.length;++i) {
			Vertex curVertex = can[i];
			int curPos = curVertex.pos;
			if(curPos == oldPos) {
				if(equals(oldVertex.stats, curVertex.stats))
					curVertex.pos = oldVertex.pos;
				else {
					// Same old group but different statements: new group starts here.
					curVertex.pos = i;
					modified = true;
				}
			}
			else
				oldPos = curPos;
			oldVertex = curVertex;
		}
		return modified;
	}
	
	/**
	 * Writes the current position of every vertex into {@code map} at the
	 * index of the vertex's original resource.
	 */
	private static void updateMap(Vertex[] vertices, int[] map) {
		for(int k=0;k<vertices.length;++k) {
			Vertex vertex = vertices[k];
			map[vertex.original] = vertex.pos;
		}
	}

	/**
	 * Collects the indices at which a group begins, i.e. the indices
	 * {@code i} with {@code can[i].pos == i}, and appends
	 * {@code can.length} as a terminating boundary.
	 *
	 * @return group start indices in increasing order followed by the array length
	 */
	private static int[] groupPositions(Vertex[] can) {
		TIntArrayList starts = new TIntArrayList();
		int index = 0;
		for(Vertex vertex : can) {
			if(vertex.pos == index)
				starts.add(index);
			++index;
		}
		// index now equals can.length
		starts.add(index);
		return starts.toArray();
	}
		
	/**
	 * Map from byte arrays to ints whose equality and hash hooks work on
	 * the array contents instead of the array identity.
	 * NOTE(review): presumably these are the hooks the Trove base class
	 * consults for key lookup - confirm against the Trove version in use.
	 * The class is not referenced elsewhere in this file.
	 */
	static class TByteArrayIntHashMap extends TObjectIntHashMap<byte[]> {
		/** Compares the two keys element by element. */
		@Override
		protected boolean equals(Object one, Object two) {
			byte[] left = (byte[])one;
			byte[] right = (byte[])two;
			return Arrays.equals(left, right);
		}
		
		/** Hashes the key by its contents. */
		@Override
		protected int hash(Object obj) {
			byte[] key = (byte[])obj;
			return Arrays.hashCode(key);
		}
	}
	
	/**
	 * Splits the group {@code can[begin..end)} by the values attached to
	 * its vertices. Vertices are distributed into buckets in order of the
	 * first occurrence of their value; inside a bucket the previous
	 * relative order is kept. Every vertex gets the start index of its
	 * bucket as its new position.
	 *
	 * @param can     all vertices; only the given range is read and rewritten
	 * @param begin   first index of the group (inclusive)
	 * @param end     end index of the group (exclusive)
	 * @param aValues values of graph A indexed by resource (used when {@code graph == 0})
	 * @param bValues values of graph B indexed by resource (used otherwise)
	 * @return {@code true} if more than one distinct value was found and
	 *         the group was rearranged, {@code false} if nothing was changed
	 */
	private boolean separateByValues(Vertex[] can, int begin, int end, Variant[] aValues, Variant[] bValues) {
		int length = end - begin;
		int[] bucketOf = new int[length];
		// Stored ids are 1-based because get() answering 0 is taken to mean
		// "value not seen yet".
		TObjectIntHashMap<Variant> knownValues = new TObjectIntHashMap<Variant>();
		int bucketCount = 0;
		for(int k=0;k<length;++k) {
			Vertex vertex = can[begin+k];
			Variant[] source = vertex.graph==0 ? aValues : bValues;
			Variant value = source[vertex.original];
			int stored = knownValues.get(value);
			if(stored == 0) {
				stored = ++bucketCount;
				knownValues.put(value, stored);
			}
			bucketOf[k] = stored - 1;
		}
		if(bucketCount <= 1)
			return false;

		Vertex[] snapshot = Arrays.copyOfRange(can, begin, end);

		// Turn bucket sizes into bucket start indices.
		int[] cursor = new int[bucketCount];
		for(int k=0;k<length;++k)
			++cursor[bucketOf[k]];
		int next = begin;
		for(int bucket=0;bucket<bucketCount;++bucket) {
			int size = cursor[bucket];
			cursor[bucket] = next;
			next += size;
		}

		// Positions must be assigned before the cursors start moving.
		for(int k=0;k<length;++k)
			snapshot[k].pos = cursor[bucketOf[k]];
		for(int k=0;k<length;++k) {
			int bucket = bucketOf[k];
			can[cursor[bucket]] = snapshot[k];
			++cursor[bucket];
		}
		return true;
	}
	
	/**
	 * Runs the value based separation on every group that has more than
	 * two members.
	 *
	 * @param groupPos group boundaries as produced by {@code groupPositions}
	 * @return {@code true} if at least one group was rearranged
	 */
	private boolean separateByValues(Vertex[] can, int[] groupPos, Variant[] aValues, Variant[] bValues) {
		boolean anySplit = false;
		for(int g=1;g<groupPos.length;++g) {
			int begin = groupPos[g-1];
			int end = groupPos[g];
			if(end - begin <= 2)
				continue;
			if(separateByValues(can, begin, end, aValues, bValues))
				anySplit = true;
		}
		return anySplit;
	}
	
	/**
	 * Tells whether some group has more than two members, starts with a
	 * vertex whose graph tag is 0 and ends with a vertex whose graph tag
	 * is 1.
	 *
	 * @param groupPos group boundaries as produced by {@code groupPositions}
	 */
	private boolean hasBigGroups(Vertex[] can, int[] groupPos) {
		for(int g=1;g<groupPos.length;++g) {
			int begin = groupPos[g-1];
			int end = groupPos[g];
			if(end - begin <= 2)
				continue;
			boolean startsInA = can[begin].graph == 0;
			if(startsInA && can[end-1].graph == 1)
				return true;
		}
		return false;
	}
	
	/**
	 * Disjoint-set structure over the integers {@code 0..size-1} with
	 * path compression.
	 */
	static class UnionFind {
		/** Parent links; an element that is its own parent is the representative of its set. */
		int[] canonical;
		
		/**
		 * Creates {@code size} singleton sets.
		 */
		public UnionFind(int size) {
			canonical = new int[size];
			for(int i=0;i<size;++i)
				canonical[i] = i;
		}
		
		/**
		 * Returns the representative of the set containing {@code a} and
		 * points every element on the visited path directly to it.
		 * Implemented iteratively: {@link #merge} applies no rank
		 * heuristic, so parent chains may become long and a recursive
		 * walk could exhaust the thread stack.
		 */
		public int canonical(int a) {
			// First pass: locate the root.
			int root = a;
			while(canonical[root] != root)
				root = canonical[root];
			// Second pass: compress the path.
			int cur = a;
			while(cur != root) {
				int next = canonical[cur];
				canonical[cur] = root;
				cur = next;
			}
			return root;
		}
		
		/**
		 * Unites the sets containing {@code a} and {@code b}; the
		 * representative of {@code b}'s set becomes the representative of
		 * the union.
		 */
		public void merge(int a, int b) {
			a = canonical(a);
			b = canonical(b);
			canonical[a] = b;
		}
	}
	
	/**
	 * Breaks ties in groups that could not be refined any further by
	 * pairing their members arbitrarily. Vertices are first clustered
	 * with a union-find over their positions and the non-negative
	 * {@code p}/{@code o} entries of their statements; for each cluster
	 * at most one qualifying group is processed per call. In a processed
	 * group the k-th vertex with graph tag 0 and the k-th remaining
	 * vertex get the common position {@code begin + 2k}; vertices left
	 * without a partner share the position after the last pair.
	 *
	 * @param can      vertices; assumed ordered so that inside a group
	 *                 vertices with graph tag 0 precede the others
	 * @param groupPos group boundaries as produced by {@code groupPositions}
	 */
	private static void guessIsomorphism(Vertex[] can, int[] groupPos) {
		UnionFind components = new UnionFind(can.length);
		for(int index=0;index<can.length;++index) {
			Vertex vertex = can[index];
			components.merge(index, vertex.pos);
			for(Stat stat : vertex.stats) {
				if(stat.p >= 0)
					components.merge(index, stat.p);
				if(stat.o >= 0)
					components.merge(index, stat.o);
			}
		}
		
		TIntHashSet handled = new TIntHashSet();
		for(int g=0;g<groupPos.length-1;++g) {
			int begin = groupPos[g];
			int end = groupPos[g+1];
			boolean ambiguous = end - begin > 2
					&& can[begin].graph == 0 && can[end-1].graph == 1;
			if(!ambiguous)
				continue;
			// Only one group per cluster is guessed in a single call.
			if(!handled.add(components.canonical(begin)))
				continue;

			// Find the first vertex that does not have graph tag 0.
			int middle = begin+1;
			while(can[middle].graph==0)
				++middle;
			int aCount = middle - begin;
			int bCount = end - middle;
			int pairs = Math.min(aCount, bCount);
			for(int k=0;k<pairs;++k) {
				int pairPos = begin + k*2;
				can[begin+k].pos = pairPos;
				can[middle+k].pos = pairPos;
			}
			// Only one of the following loops can have work to do.
			int restPos = begin + pairs*2;
			for(int k=pairs;k<aCount;++k)
				can[begin+k].pos = restPos;
			for(int k=pairs;k<bCount;++k)
				can[middle+k].pos = restPos;
		}
	}
	
	/**
	 * Extends the given matching by refining the not yet matched
	 * resources of both graphs into groups and mapping every group that
	 * ends up with exactly one vertex of graph A (tag 0) followed by one
	 * vertex of graph B (tag 1).
	 * <p>
	 * Nothing is done when {@code matching.size} equals the resource
	 * count of either graph. Otherwise the vertices of both graphs are
	 * repeatedly sorted and split by their statements. When a pass splits
	 * nothing and a group with more than two members spanning both graphs
	 * remains, the groups are separated once by their values and, if that
	 * changed nothing or was already done, ties are broken with
	 * {@code guessIsomorphism}. The loop ends when no such group is left.
	 * <p>
	 * Timing and debug information is printed to {@code System.out} when
	 * {@code GraphMatching.TIMING} / {@code GraphMatching.DEBUG} are set.
	 *
	 * @param matching the matching to extend; new pairs are added with {@code matching.map}
	 */
	@Override
	public void applyTo(GraphMatching matching) {
		if(matching.size == matching.aGraph.resourceCount ||
				matching.size == matching.bGraph.resourceCount)
			return;
		long begin1 = System.nanoTime();
		int[] aMap = generateMapA(matching.aToB);
		int[] bMap = generateMapB(matching.bToA);
		Vertex[] aVertices = generateVertices(0,
				aMap, matching.aGraph.statements);
		Vertex[] bVertices = generateVertices(1,
				bMap, matching.bGraph.statements);
		Vertex[] can = concat(aVertices, bVertices);
		if(GraphMatching.TIMING)
			System.out.println("    Init:    " + (System.nanoTime()-begin1)*1e-6 + "ms");
		
		int[] groupPos = null;
		boolean doneSeparationByValues = false;
		while(true) {
			long begin2 = System.nanoTime();
			Arrays.sort(can, VERTEX_COMPARATOR);
			if(GraphMatching.TIMING)
				System.out.println("    Sort:    " + (System.nanoTime()-begin2)*1e-6 + "ms");
			
			long begin3 = System.nanoTime();
			if(!updatePositions(can)) {
				// The statements cannot split any group further.
				groupPos = groupPositions(can);
				if(!hasBigGroups(can, groupPos))
					break;
				
				// Value based separation is attempted only once.
				boolean modified = false;
				if(!doneSeparationByValues) {
					modified = separateByValues(can, groupPos, matching.aGraph.values, matching.bGraph.values);
					doneSeparationByValues = true;
					if(GraphMatching.TIMING)
						System.out.println("    - separate by values");
				}
				
				if(!modified) {
					guessIsomorphism(can, groupPos);
					if(GraphMatching.TIMING)
						System.out.println("    - guess isomorphism");
				}
			}
			if(GraphMatching.TIMING)
				System.out.println("    Update1: " + (System.nanoTime()-begin3)*1e-6 + "ms");
			
			// Propagate the new positions into the maps ...
			long begin4 = System.nanoTime();
			updateMap(aVertices, aMap);
			updateMap(bVertices, bMap);
			if(GraphMatching.TIMING)
				System.out.println("    Update2: " + (System.nanoTime()-begin4)*1e-6 + "ms");
			// ... and from the maps into the translated statements.
			long begin5 = System.nanoTime();
			updateVertices(aVertices, aMap, matching.aGraph.statements);
			updateVertices(bVertices, bMap, matching.bGraph.statements);
			if(GraphMatching.TIMING)
				System.out.println("    Update3: " + (System.nanoTime()-begin5)*1e-6 + "ms");
		}
		
		// Map every group consisting of exactly one A vertex and one B vertex.
		for(int i=0;i<groupPos.length-1;++i) {
			int begin = groupPos[i];
			int end = groupPos[i+1];
			if(end - begin == 2) {
				Vertex a = can[begin];
				Vertex b = can[end-1];
				if(a.graph == 0 && b.graph == 1)
					matching.map(a.original, b.original);
			}
		}
		
		if(GraphMatching.DEBUG)
			printLargeGroups(matching, can, groupPos);
	}

	/**
	 * Prints every group with more than two members to {@code System.out},
	 * each group preceded by a separator line.
	 */
	private static void printLargeGroups(GraphMatching matching, Vertex[] can, int[] groupPos) {
		for(int i=0;i<groupPos.length-1;++i) {
			int begin = groupPos[i];
			int end = groupPos[i+1];
			if(end - begin > 2) {
				System.out.println("----");
				for(int j=begin;j<end;++j) {
					Vertex vertex = can[j];
					if(vertex.graph == 0)
						printResource(vertex.original, " (A)", matching.aGraph.names,
								matching.aGraph.statements, matching.aGraph.values);
					else
						printResource(vertex.original, " (B)", matching.bGraph.names,
								matching.bGraph.statements, matching.bGraph.values);
				}
			}
		}
	}

	/**
	 * Prints the name of one resource followed by the given suffix, then
	 * its statements and, if present, its value (each indented).
	 */
	private static void printResource(int org, String suffix, String[] names,
			Stat[][] statements, Variant[] values) {
		String name = names[org];
		System.out.println(name + suffix);
		for(Stat stat : statements[org])
			System.out.println("    " + stat.toString(names));
		Variant value = values[org];
		if(value != null)
			System.out.println("    " + value);
	}
}
