/*******************************************************************************
 * Copyright (c) 2007, 2010 Association for Decentralized Information Management
 * in Industry THTH ry.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *     VTT Technical Research Centre of Finland - initial API and implementation
 *******************************************************************************/
package org.simantics.utils.datastructures.persistent;

import gnu.trove.procedure.TObjectProcedure;

import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;

/**
 * A persistent sorted set of {@link Comparable} elements stored as a
 * red-black tree.
 * 
 * Every modifying operation ({@link #add(Comparable)},
 * {@link #remove(Comparable)}, {@code addAll}) returns a set and leaves the
 * receiver untouched; subtrees that are not affected by the operation are
 * shared between the old and the new set.
 * 
 * The empty set is represented by the single shared instance {@link #EMPTY},
 * which also serves as the leaf sentinel of every tree. Use {@link #empty()}
 * or one of the {@code of} factories to obtain a set.
 * 
 * Elements are ordered with {@code compareTo}; passing {@code null} to
 * {@code contains}, {@code add} or {@code remove} of a nonempty set results
 * in a {@link NullPointerException} from the comparison.
 *
 * @param <T> element type
 */
public class ImmutableSet<T extends Comparable<T>> {
	
	/**
	 * The shared empty set. It is black, has no key and no children, and is
	 * compared by identity throughout this class.
	 */
	@SuppressWarnings({ "rawtypes" })
	static final ImmutableSet EMPTY = new EmptyImmutableSet();
	
	/**
	 * The empty set / leaf sentinel. Overrides every operation that would
	 * otherwise dereference the (null) children or key of the node.
	 */
	private static class EmptyImmutableSet<T extends Comparable<T>> extends ImmutableSet<T> {
		public EmptyImmutableSet() {
			isBlack = true;
		}
		
		/**
		 * Inserting into the empty tree yields a fresh red single-element node.
		 */
		@Override
		protected ImmutableSet<T> addRec(T obj) {
			return new ImmutableSet<T>(obj);
		}
		
		/**
		 * Nothing can be removed from the empty tree; null signals "not found".
		 */
		@Override
		protected ImmutableSet<T> removeRec(T obj) {
			return null;
		}
		
		@Override
		public boolean contains(T obj) {
			return false;
		}
		
		/**
		 * The empty set has no first element.
		 * 
		 * @throws java.util.NoSuchElementException always
		 */
		@Override
		public T getFirst() {
			throw new java.util.NoSuchElementException("ImmutableSet is empty");
		}
		
		/**
		 * There is nothing to visit, so the traversal is never interrupted.
		 */
		@Override
		public boolean forEach(TObjectProcedure<T> arg0) {
			return true;
		}
		
		@Override
		void print(int arg0) {		
		}
	}
	
	/** Subtree with the keys smaller than {@link #key}; {@link #EMPTY} if none. */
	ImmutableSet<T> left;
	/** The element stored in this node (null only in {@link #EMPTY}). */
	T key;
	/** Subtree with the keys greater than {@link #key}; {@link #EMPTY} if none. */
	ImmutableSet<T> right;
	/** Node color: true for black, false for red. */
	boolean isBlack;
	
	/**
	 * Creates a node from its parts without any balancing.
	 * 
	 * @param left subtree of smaller keys
	 * @param key element of this node
	 * @param right subtree of greater keys
	 * @param isBlack true for a black node, false for a red one
	 */
	protected ImmutableSet(
			ImmutableSet<T> left,
			T key,
			ImmutableSet<T> right,			
			boolean isBlack) {
		this.left = left;
		this.right = right;
		this.key = key;
		this.isBlack = isBlack;
	}

	/**
	 * Creates a single-element node with empty children. The node is created
	 * red because it is also used as the freshly inserted leaf in
	 * {@link #addRec(Comparable)}; sets obtained through {@link #add},
	 * {@link #remove} or the {@code of} factories always have a black root.
	 * 
	 * @param key the only element of the new node
	 */
	@SuppressWarnings("unchecked")
	public ImmutableSet(T key) {
		this(EMPTY, key, EMPTY, false);
	}
	
	/**
	 * Creates a set containing the given elements. Duplicates are stored once.
	 * 
	 * @param array elements of the set, possibly none
	 * @return the shared empty set if array is empty, otherwise a set with a
	 *         black root containing all the elements
	 */
	@SuppressWarnings("unchecked")
	public static <T extends Comparable<T>> ImmutableSet<T> of(T ... array) {
		// Start from EMPTY so that even a one-element result goes through
		// add() and gets a black root.
		ImmutableSet<T> ret = EMPTY;
		for(T element : array)
			ret = ret.add(element);
		return ret;
	}
	
	/**
	 * Creates a set containing the elements of the given collection.
	 * Duplicates are stored once.
	 * 
	 * @param c source collection, possibly empty
	 * @return the shared empty set if c is empty, otherwise a set with a
	 *         black root containing all the elements
	 */
	@SuppressWarnings("unchecked")
	public static <T extends Comparable<T>> ImmutableSet<T> of(Collection<T> c) {
		// Start from EMPTY so that even a one-element result goes through
		// add() and gets a black root.
		ImmutableSet<T> ret = EMPTY;
		Iterator<T> it = c.iterator();
		while(it.hasNext())
			ret = ret.add(it.next());
		return ret;
	}

	/**
	 * Leaves every field at its default value; used only by the empty set.
	 */
	private ImmutableSet() {	
	}
	
	/**
	 * Tests membership by descending the tree according to compareTo.
	 * 
	 * @param obj element to look for
	 * @return true if an element comparing equal to obj is in the set
	 */
	public boolean contains(T obj) {
		int cmp = obj.compareTo(key);
		if(cmp < 0)
			return left.contains(obj);
		else if(cmp > 0)
			return right.contains(obj);
		else
			return true;
	}	
	
	/**
	 * Inserts obj into this subtree. The root of the result may be red, and a
	 * red node may temporarily have a red child; {@link #add} and the
	 * enclosing black node's {@link #balance} call repair that.
	 * 
	 * @param obj element to insert
	 * @return this very instance if obj is already present (callers detect
	 *         "no change" by identity), otherwise a new subtree
	 */
	protected ImmutableSet<T> addRec(T obj) {
		int cmp = obj.compareTo(key);
		if(cmp < 0) {
			ImmutableSet<T> newLeft = left.addRec(obj);
			if(newLeft == left)
				return this;			
			if(isBlack)
				return balance(newLeft, key, right);
			else
				return red(newLeft, key, right);
		}
		else if(cmp > 0) {
			ImmutableSet<T> newRight = right.addRec(obj);
			if(newRight == right)
				return this;
			if(isBlack)
				return balance(left, key, newRight);
			else
				return red(left, key, newRight);
		}
		else
			return this;
	}
	
	/**
	 * Removes obj from this (nonempty) subtree.
	 * 
	 * If the child that is descended into is black, the result is rebuilt
	 * with {@link #balleft}/{@link #balright}; if it is red, the node is simply
	 * rebuilt red. NOTE: according to the original author's comment, for a
	 * black node the returned tree has one black node less in every branch;
	 * {@link #remove} recolors the final root black.
	 *   
	 * @param obj element to remove
	 * @return the new subtree (its root may be red), or null if obj is not
	 *         contained in this subtree and nothing was changed
	 */
	protected ImmutableSet<T> removeRec(T obj) {		
		int cmp = obj.compareTo(key);
		if(cmp < 0) {
			ImmutableSet<T> newLeft = left.removeRec(obj);
			if(newLeft == null)
				return null;
			else if(left.isBlack)
				return balleft(newLeft, key, right);
			else
				return red(newLeft, key, right);
		}
		else if(cmp > 0) {
			ImmutableSet<T> newRight = right.removeRec(obj);
			if(newRight == null)
				return null;
			else if(right.isBlack)
				return balright(left, key, newRight);
			else
				return red(left, key, newRight);			
		}
		else
			return append(left, right);
	}
	
	/**
	 * Joins two subtrees whose common parent key has been removed.
	 * 
	 * Assumes a and b have the same black depth and keys in a are strictly less
	 * than keys in b. Creates a new tree with the same black depth as a and b. 
	 * 
	 * @param <T> element type
	 * @param a tree of the smaller keys
	 * @param b tree of the greater keys
	 * @return a tree containing the keys of both a and b
	 */
	protected static <T extends Comparable<T>> ImmutableSet<T> append(ImmutableSet<T> a, ImmutableSet<T> b) {
		if(a==EMPTY)
			return b;
		if(b==EMPTY)
			return a;
		if(a.isBlack) {
			if(b.isBlack) {
				// Both black: join the inner subtrees first, then either
				// rebalance or lift the red middle node to the top.
				ImmutableSet<T> middle = append(a.right, b.left);
				if(middle.isBlack)
					return balleft(a.left, a.key, black(middle, b.key, b.right));
				else
					return red(black(a.left, a.key, middle.left), middle.key, black(middle.right, b.key, b.right));
			}
			else
				return red(append(a, b.left), b.key, b.right);
		}
		else {
			if(b.isBlack)
				return red(a.left, a.key, append(a.right, b));
			else {
				// Both red: same scheme as above with red nodes.
				ImmutableSet<T> middle = append(a.right, b.left);
				if(middle.isBlack)
					return red(a.left, a.key, red(middle, b.key, b.right));
				else
					return red(red(a.left, a.key, middle.left), middle.key, red(middle.right, b.key, b.right));
			}
		}
	}
	
	/**
	 * Returns the smallest element, i.e. the key of the leftmost node.
	 * 
	 * @return the smallest element of the set
	 * @throws java.util.NoSuchElementException if the set is empty
	 */
	public T getFirst() {
		// The empty set overrides this method, so "this" is a real node here.
		ImmutableSet<T> node = this;
		while(node.left != EMPTY)
			node = node.left;
		return node.key;
	}	
	
	/**
	 * Builds a node that takes the place of a black node with the given
	 * parts. If one child is red and has a red child itself, the three nodes
	 * are rearranged into a red node with two black children; otherwise a
	 * plain black node is created.
	 */
	static private <T extends Comparable<T>> ImmutableSet<T> balance(
			ImmutableSet<T> left,
			T key,
			ImmutableSet<T> right) {
		if(!left.isBlack) {
			if(!left.left.isBlack) 
				return red(
					toBlack(left.left),
					left.key,
					black(left.right, key, right)
				);
			else if(!left.right.isBlack)
				return red(
					black(left.left, left.key, left.right.left),
					left.right.key,
					black(left.right.right, key, right)					
				);				
		}
		if(!right.isBlack) {
			if(!right.left.isBlack)
				return red(
					black(left, key, right.left.left),
					right.left.key,
					black(right.left.right, right.key, right.right)
				);
			else if(!right.right.isBlack)
				return red(
					black(left, key, right.left),
					right.key,
					toBlack(right.right)
				);		
		}
		return black(left, key, right);
	}
	
	/**
	 * Creates a black node from the given parts.
	 */
	static private <T extends Comparable<T>> ImmutableSet<T> black(
			ImmutableSet<T> left,
			T key,
			ImmutableSet<T> right) {
		return new ImmutableSet<T>(left, key, right, true);
	}
	
	/**
	 * Creates a red node from the given parts.
	 */
	static private <T extends Comparable<T>> ImmutableSet<T> red(
			ImmutableSet<T> left,
			T key,
			ImmutableSet<T> right) {
		return new ImmutableSet<T>(left, key, right, false);
	}
	
	/**
	 * Returns tree itself if its root is black, otherwise a black copy of
	 * the root node (children are shared). Never modifies tree.
	 */
	static private <T extends Comparable<T>> ImmutableSet<T> toBlack(ImmutableSet<T> tree) {
		if(tree.isBlack)
			return tree;
		else
			return black(tree.left, tree.key, tree.right);
	}
	
	/**
	 * Returns tree itself if its root is red, otherwise a red copy of the
	 * root node (children are shared). Never modifies tree.
	 */
	static private <T extends Comparable<T>> ImmutableSet<T> toRed(ImmutableSet<T> tree) {
		if(tree.isBlack)
			return red(tree.left, tree.key, tree.right);
		else
			return tree;
	}
			
	
	/**
	 * Rebuilds a node after a removal in its left subtree (called by
	 * {@link #removeRec} when the original left child was black, and by
	 * {@link #append}).
	 * 
	 * @param left the left subtree after the removal
	 * @param key key of the node being rebuilt
	 * @param right the unchanged right subtree
	 */
	static private <T extends Comparable<T>> ImmutableSet<T> balleft(
			ImmutableSet<T> left,
			T key,
			ImmutableSet<T> right) {
		if(left.isBlack) {
			if(right.isBlack)
				return balance(left, key, toRed(right));
			else
				return red(black(left, key, right.left.left), right.left.key, balance(right.left.right, right.key, toRed(right.right)));
		}
		else
			return red(toBlack(left), key, right);
	}
	
	/**
	 * Mirror image of {@link #balleft}: rebuilds a node after a removal in
	 * its right subtree (called by {@link #removeRec} when the original right
	 * child was black).
	 * 
	 * @param left the unchanged left subtree
	 * @param key key of the node being rebuilt
	 * @param right the right subtree after the removal
	 */
	static private <T extends Comparable<T>> ImmutableSet<T> balright(
			ImmutableSet<T> left,
			T key,
			ImmutableSet<T> right) {
		if(right.isBlack) {
			if(left.isBlack)
				return balance(toRed(left), key, right);
			else
				return red(balance(toRed(left.left), left.key, left.right.left), left.right.key, black(left.right.right, key, right));
		}
		else
			return red(left, key, toBlack(right));
	}

	/**
	 * Returns a set that contains obj in addition to the elements of this
	 * set. This set is not modified.
	 * 
	 * @param obj element to add
	 * @return a set with a black root containing obj
	 */
	public ImmutableSet<T> add(T obj) {
		// toBlack copies the root when it has to be recolored instead of
		// flipping the color in place: addRec may hand back an existing
		// node (even this), which must not be mutated.
		return toBlack(addRec(obj));
	}

	/**
	 * Returns a set without obj. This set is not modified.
	 * 
	 * @param obj element to remove
	 * @return this very instance if obj is not in the set, otherwise a set
	 *         with a black root that does not contain obj
	 */
	public ImmutableSet<T> remove(T obj) {
		ImmutableSet<T> tree = removeRec(obj);
		if(tree == null)
			return this;
		return toBlack(tree);
	}

	/**
	 * Calls procedure for the elements in ascending order (left subtree,
	 * key, right subtree) until it returns false.
	 * 
	 * @param procedure callback executed for each element
	 * @return true if every element was visited, false if procedure stopped
	 *         the traversal
	 */
	public boolean forEach(TObjectProcedure<T> procedure) {
		if(left.forEach(procedure))
			if(procedure.execute(key))
				return right.forEach(procedure);
		return false;
	}
	
	/**
	 * Returns a set containing the elements of this set and all elements
	 * produced by c. This set is not modified.
	 * 
	 * @param c elements to add
	 * @return the union
	 */
	public ImmutableSet<T> addAll(Iterable<T> c) {
		ImmutableSet<T> ret = this;
		for(T t : c)
			ret = ret.add(t);
		return ret;
	}
	
	/**
	 * Procedure that accumulates every visited element into {@link #result}.
	 */
	static class AddAll<T extends Comparable<T>> implements TObjectProcedure<T> {

		/** The set built so far; replaced on every visited element. */
		ImmutableSet<T> result;		
		
		/**
		 * @param result the set the visited elements are added to
		 */
		public AddAll(ImmutableSet<T> result) {
			this.result = result;
		}

		@Override
		public boolean execute(T arg0) {
			result = result.add(arg0);
			return true;
		}
		
	}
	
	/**
	 * Returns the union of this set and c. If either set is the empty set,
	 * the other one is returned as is; otherwise the elements of c are added
	 * to this set one by one. Neither set is modified.
	 * 
	 * @param c the set to merge in
	 * @return the union
	 */
	public ImmutableSet<T> addAll(ImmutableSet<T> c) {
		if(this==EMPTY)
			return c;
		if(c==EMPTY)
			return this;
		AddAll<T> proc = new AddAll<T>(this);
		c.forEach(proc);
		return proc.result;
	}
	
	/**************************************************************************
	 * 
	 *  Testing
	 * 
	 ************************************************************************** 
	 */
	
	/**
	 * Prints the tree sideways to System.out, one node per line, indented by
	 * four spaces per tree level and annotated with the node color.
	 */
	void print(int indentation) {
		left.print(indentation + 1);
		for(int i=0;i<indentation;++i)
			System.out.print("    ");
		System.out.println(key + " " + (isBlack ? "BLACK" : "RED"));
		right.print(indentation + 1);
	}
	
	/**
	 * Checks the red-black invariants of this subtree and reports violations
	 * on System.out.
	 * 
	 * @return the number of black nodes on the leftmost path of this
	 *         subtree, counting the empty leaf as one
	 */
	private int validateRec() {
		if(this==EMPTY)
			return 1;
		int lh = left.validateRec();
		int rh = right.validateRec();
		if(lh != rh)
			System.out.println("Unbalanced!");
		if(!isBlack) {
			if(!left.isBlack || !right.isBlack)
				System.out.println("Red under red");
			return lh;
		}
		else
			return lh+1;
	}	

	/**
	 * Returns the shared empty set, typed for the caller.
	 */
	@SuppressWarnings("unchecked")
	public static <T extends Comparable<T>> ImmutableSet<T> empty() {
		return EMPTY;
	}
	
	/**
	 * Checks the red-black invariants of the whole set (equal black depth on
	 * both sides of every node, no red node with a red child, black root)
	 * and prints a message to System.out for every violation found.
	 */
	public void validate() {
		validateRec();
		if(!isBlack)
			System.out.println("Root is not black");
	}

	/**
	 * Randomized self-test: applies the same random additions and removals
	 * to an ImmutableSet and to a HashSet, validating the tree and comparing
	 * the contents after every step. Prints only when a discrepancy is found.
	 */
	public static void main(String[] args) {
		ImmutableSet<Integer> set = ImmutableSet.empty();
		final Set<Integer> reference = new HashSet<Integer>();

		Random random = new Random();
		for(int i=0;i<10000;++i) {
			int r1 = random.nextInt(1000);
			int r2 = random.nextInt(1000);
			set = set.add(r1);
			set = set.remove(r2);
			reference.add(r1);
			reference.remove(r2);
			set.validate();			
			
			for(Integer v : reference)
				if(!set.contains(v))
					System.out.println("set2 has more elements");
			set.forEach(new TObjectProcedure<Integer>() {

				@Override
				public boolean execute(Integer arg0) {
					if(!reference.contains(arg0))
						System.out.println("set has more elements");
					return true;
				}
				
			});
		}
	}
	
}
