package org.simantics.utils.datastructures.collections;

import java.awt.geom.Point2D;
import java.awt.geom.Rectangle2D;
import java.util.HashSet;
import java.util.Set;

/**
 * A fixed-depth region quadtree that buckets objects by their bounding rectangles.
 * <p>
 * The tree is subdivided eagerly to the requested depth at construction time, and objects are
 * stored only in leaf nodes. An object whose bounds straddle a node's center is stored in every
 * child quadrant it may touch, so the same object can be held by several leaves.
 * <p>
 * Quadrant selection compares bounds only against node centers, never against a node's outer
 * extent. Objects lying outside the tree's area are therefore stored in the leaves nearest to
 * them rather than being rejected.
 * <p>
 * This class is not synchronized.
 *
 * @param <T> type of the stored objects
 */
public class QuadTree<T> {

	/** Bit flag for the child with x and y above the center. */
	private static final int Q_PXPY = 1;
	/** Bit flag for the child with x above and y below the center. */
	private static final int Q_PXNY = 2;
	/** Bit flag for the child with x below and y above the center. */
	private static final int Q_NXPY = 4;
	/** Bit flag for the child with x and y below the center. */
	private static final int Q_NXNY = 8;

	Point2D center;
	/** Objects held by this node; only allocated for leaves (null on inner nodes). */
	Set<T> contains;
	double width;
	double height;

	boolean leaf;
	QuadTree<T> pXpY;
	QuadTree<T> nXpY;
	QuadTree<T> pXnY;
	QuadTree<T> nXnY;

	/**
	 * Creates a quadtree that is subdivided eagerly to the given depth.
	 * <p>
	 * A tree of depth {@code d} allocates 4^d leaves, so memory use grows exponentially with
	 * {@code depth}.
	 *
	 * @param center center of the quadtree area; a copy is stored, so later changes to the
	 *        passed point do not affect the tree
	 * @param width width of the area
	 * @param height height of the area
	 * @param depth number of subdivision levels; values {@code <= 0} create a single leaf
	 * @throws NullPointerException if {@code center} is null
	 */
	public QuadTree(Point2D center, double width, double height, int depth) {
		// Point2D is mutable; if the caller moved the root center after the children were laid
		// out, quadrant selection would no longer match the child geometry.
		this.center = (Point2D) center.clone();
		this.width = width;
		this.height = height;
		this.leaf = true;
		split(depth);
	}

	/**
	 * Creates an unsplit leaf node. The center is used as-is; callers pass a freshly
	 * allocated point.
	 */
	private QuadTree(Point2D center, double width, double height) {
		this.center = center;
		this.width = width;
		this.height = height;
		this.leaf = true;
	}

	/**
	 * Inserts an object into every leaf whose quadrant the given bounds may overlap.
	 *
	 * @param object the object to store
	 * @param bounds bounding rectangle of the object, used only to route it to leaves; it is
	 *        not retained
	 */
	public void insert(T object, Rectangle2D bounds) {
		if (leaf) {
			contains.add(object);
			return;
		}
		int quadrants = overlappingQuadrants(bounds);
		if ((quadrants & Q_PXPY) != 0) {
			pXpY.insert(object, bounds);
		}
		if ((quadrants & Q_PXNY) != 0) {
			pXnY.insert(object, bounds);
		}
		if ((quadrants & Q_NXPY) != 0) {
			nXpY.insert(object, bounds);
		}
		if ((quadrants & Q_NXNY) != 0) {
			nXnY.insert(object, bounds);
		}
	}

	/**
	 * Removes an object from the tree.
	 * <p>
	 * The original insertion bounds are not stored, so every leaf is visited.
	 *
	 * @param object the object to remove
	 */
	public void remove(T object) {
		if (leaf) {
			contains.remove(object);
		} else {
			pXpY.remove(object);
			pXnY.remove(object);
			nXpY.remove(object);
			nXnY.remove(object);
		}
	}

	/**
	 * Adds to {@code set} the candidate objects for the given area.
	 * <p>
	 * The candidates are all objects stored in every leaf that the area may overlap. The
	 * results are NOT filtered against the objects' own bounds, so {@code set} can receive
	 * objects that do not actually intersect {@code bounds}. Callers needing exact results
	 * must test the candidates themselves.
	 *
	 * @param bounds the query area
	 * @param set receives the candidate objects; existing contents are kept
	 */
	public void get(Rectangle2D bounds, Set<T> set) {
		if (leaf) {
			set.addAll(contains);
			return;
		}
		int quadrants = overlappingQuadrants(bounds);
		if ((quadrants & Q_PXPY) != 0) {
			pXpY.get(bounds, set);
		}
		if ((quadrants & Q_PXNY) != 0) {
			pXnY.get(bounds, set);
		}
		if ((quadrants & Q_NXPY) != 0) {
			nXpY.get(bounds, set);
		}
		if ((quadrants & Q_NXNY) != 0) {
			nXnY.get(bounds, set);
		}
	}

	/**
	 * Removes all objects from the tree, keeping its subdivision structure.
	 */
	public void clear() {
		if (leaf) {
			contains.clear();
		} else {
			pXpY.clear();
			pXnY.clear();
			nXpY.clear();
			nXnY.clear();
		}
	}

	/**
	 * Computes which child quadrants of this (inner) node the given bounds may overlap.
	 * <p>
	 * Comparisons are strict, so bounds that touch or cross a center line select both sides
	 * of it.
	 *
	 * @param bounds the rectangle to route
	 * @return a bitmask of {@code Q_*} flags
	 */
	private int overlappingQuadrants(Rectangle2D bounds) {
		boolean pX = bounds.getMinX() > center.getX();
		boolean nX = bounds.getMaxX() < center.getX();
		boolean pY = bounds.getMinY() > center.getY();
		boolean nY = bounds.getMaxY() < center.getY();
		// A side is skipped only when the bounds lie entirely on the opposite side. The
		// "pX ||" / "pY ||" terms make the positive side win for degenerate rectangles whose
		// max is below their min (where both flags of an axis could be true).
		boolean posX = pX || !nX;
		boolean negX = !pX;
		boolean posY = pY || !nY;
		boolean negY = !pY;
		int mask = 0;
		if (posX && posY) {
			mask |= Q_PXPY;
		}
		if (posX && negY) {
			mask |= Q_PXNY;
		}
		if (negX && posY) {
			mask |= Q_NXPY;
		}
		if (negX && negY) {
			mask |= Q_NXNY;
		}
		return mask;
	}

	/**
	 * Recursively subdivides this leaf {@code depth} more levels. At depth zero (or below) the
	 * node stays a leaf and gets its object set.
	 *
	 * @throws IllegalStateException if this node has already been split
	 */
	private void split(int depth) {
		if (!leaf) {
			throw new IllegalStateException("Node is already split");
		}
		if (depth <= 0) {
			this.contains = new HashSet<>();
			return;
		}
		split();
		depth--;
		pXpY.split(depth);
		nXpY.split(depth);
		pXnY.split(depth);
		nXnY.split(depth);
	}

	/**
	 * Replaces this leaf with four half-size children. Each child's center is offset from this
	 * center by a quarter of the width and height, and the node is marked as inner.
	 */
	private void split() {
		double w = width * 0.5;
		double h = height * 0.5;
		double wd2 = w * 0.5;
		double hd2 = h * 0.5;
		pXpY = new QuadTree<T>(new Point2D.Double(center.getX()+wd2, center.getY()+hd2), w, h);
		nXpY = new QuadTree<T>(new Point2D.Double(center.getX()-wd2, center.getY()+hd2), w, h);
		pXnY = new QuadTree<T>(new Point2D.Double(center.getX()+wd2, center.getY()-hd2), w, h);
		nXnY = new QuadTree<T>(new Point2D.Double(center.getX()-wd2, center.getY()-hd2), w, h);
		leaf = false;
	}

}
