/*******************************************************************************
 *  Copyright (c) 2010 Association for Decentralized Information Management in
 *  Industry THTH ry.
 *  All rights reserved. This program and the accompanying materials
 *  are made available under the terms of the Eclipse Public License v1.0
 *  which accompanies this distribution, and is available at
 *  http://www.eclipse.org/legal/epl-v10.html
 *
 *  Contributors:
 *      VTT Technical Research Centre of Finland - initial API and implementation
 *******************************************************************************/
package org.simantics.databoard.method;

import gnu.trove.map.hash.TObjectIntHashMap;

import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.simantics.databoard.Bindings;
import org.simantics.databoard.annotations.Union;
import org.simantics.databoard.binding.Binding;
import org.simantics.databoard.binding.RecordBinding;
import org.simantics.databoard.binding.UnionBinding;
import org.simantics.databoard.serialization.Serializer;
import org.simantics.databoard.serialization.SerializerConstructionException;
import org.simantics.databoard.util.binary.BinaryReadable;
import org.simantics.databoard.util.binary.BinaryWriteable;
import org.simantics.databoard.util.binary.InputStreamReadable;
import org.simantics.databoard.util.binary.OutputStreamWriteable;

/**
 * Connection is a class that handles request-response communication over a
 * socket. 
 * <p>
 * Requests have asynchronous result. The result can be acquired using one of 
 * the three methods:
 *  1) Blocking read AsyncResult.waitForResponse()
 *  2) Poll  AsyncResult.getResponse()
 *  3) Listen AsyncResult.setListener()
 * <p>
 * The socket must be established before Connection is instantiated.
 * Closing connection does not close its Socket.
 * If the socket is closed before the connection, an error is raised.
 * The error is reported to the registered ConnectionListeners.
 * The proper order to close a connection is to close Connection first
 * and then Socket. 
 *  
 * @author Toni Kalajainen <toni.kalajainen@vtt.fi>
 */
public class TcpConnection implements MethodInterface {

	/**
	 * Executor used by default for asynchronous socket writes. It keeps no core
	 * threads, hands each task directly to a thread (creating one if needed) and
	 * lets idle threads end after 100 ms.
	 */
	public static final ExecutorService SHARED_EXECUTOR_SERVICE = 
		new ThreadPoolExecutor(0, Integer.MAX_VALUE, 100L, TimeUnit.MILLISECONDS, new SynchronousQueue<Runnable>());
	
	/** Serializer of the message headers (the {@link Message} union) */
	static final Serializer MESSAGE_SERIALIZER = Bindings.getSerializerUnchecked( Bindings.getBindingUnchecked(Message.class) );
	static Charset UTF8 = Charset.forName("UTF8");
	
	// Handshake data of this end and of the peer
	Handshake local, remote;
	
	// Interface built from the methods announced by the peer
	Interface remoteType;
	// Method tables; the array index is the method id used on the wire
	MethodTypeDefinition[] localMethods, remoteMethods;
	// Reverse lookup: method definition -> method id
	HashMap<MethodTypeDefinition, Integer> localMethodsMap, remoteMethodsMap;
	
	Socket socket;
		
	// if false, there is an error in the socket or the connection has been shutdown.
	// volatile: written by close() and read by invoke() on other threads without a common lock
	volatile boolean active = true;

	// Objects used for handling local services 
	MethodInterface methodInterface;
		
	// Objects used for reading data
	// Requests that have been invoked but not yet answered, keyed by request id
	ConcurrentHashMap<Integer, PendingRequest> requests = new ConcurrentHashMap<Integer, PendingRequest>();  
	List<Object> inIdentities = new ArrayList<Object>();
	BinaryReadable in;
	// Largest message size accepted from the peer (negotiated in the constructor)
	int maxRecvSize;
	
	// Object used for writing data
	public ExecutorService writeExecutor = SHARED_EXECUTOR_SERVICE;
	TObjectIntHashMap<Object> outIdentities = new TObjectIntHashMap<Object>();
	BinaryWriteable out;
	// Source of request ids for outgoing requests
	AtomicInteger requestCounter = new AtomicInteger(0);
	// Largest message size that may be sent to the peer (negotiated in the constructor)
	int maxSendSize;
	
	// Cached method descriptions
	Map<String, MethodType> methodTypes = new ConcurrentHashMap<String, MethodType>();
	
	/**
	 * Handshake a socket.
	 * <p>
	 * The local handshake is written on a thread of {@link #SHARED_EXECUTOR_SERVICE}
	 * while the calling thread reads the peer's handshake, so that both ends can
	 * write and read at the same time. The method returns only after the write
	 * task has finished.
	 * 
	 * @param socket connected socket
	 * @param localData local data
	 * @return the remote data
	 * @throws IOException if reading the remote handshake or writing the local one fails
	 * @throws RuntimeException unexpected error (BindingException or EncodingException)
	 *         in the write task, or the calling thread was interrupted while waiting for it 
	 */
	public static Handshake handshake(final Socket socket, final Handshake localData)
	throws IOException
	{
		final BinaryReadable bin = new InputStreamReadable( socket.getInputStream(), Long.MAX_VALUE );
		final BinaryWriteable bout = new OutputStreamWriteable( socket.getOutputStream() );
		
		// Write the local handshake asynchronously. The semaphore hand-off makes
		// the stored error visible to the thread that acquires the permit.
		final Exception[] writeError = new Exception[1];
		final Semaphore writeDone = new Semaphore(0);
		SHARED_EXECUTOR_SERVICE.execute(new Runnable() {
			@Override
			public void run() {
				try {
					TObjectIntHashMap<Object> outIdentities = new TObjectIntHashMap<Object>();
					Handshake.SERIALIZER.serialize(bout, outIdentities, localData);
					bout.flush();
				} catch (IOException e) {
					writeError[0] = e;
				} catch (RuntimeException e) {
					// Without this the failure would vanish in the executor thread
					// and the handshake would appear to have succeeded.
					writeError[0] = e;
				} finally {
					writeDone.release(1);
				}
			}});
		
		// Read remote peer's handshake 
		List<Object> inIdentities = new ArrayList<Object>();
		Handshake result = (Handshake) Handshake.SERIALIZER.deserialize(bin, inIdentities);
		
		// Check that write was ok
		try {
			writeDone.acquire(1);
		} catch (InterruptedException e) {
			// Keep the interrupt visible to the caller
			Thread.currentThread().interrupt();
			throw new RuntimeException(e);
		}
		Exception failure = writeError[0];
		if (failure instanceof IOException) 
			throw (IOException) failure;
		if (failure != null)
			throw new RuntimeException(failure);
		
		return result;
	}
	
	/**
	 * Create a connection on a socket whose handshake has already been exchanged.
	 * The message size limits are derived from both handshakes, the method tables
	 * are indexed, and the reader thread is named after the peer address and started.
	 * 
	 * @param socket connected socket
	 * @param methodInterface local method handler 
	 * @param localData handshake data of this end
	 * @param remoteData handshake data received from the peer
	 * @throws IOException if the socket streams cannot be opened
	 * @throws IllegalArgumentException if socket, localData or remoteData is null
	 */
	public TcpConnection(Socket socket, MethodInterface methodInterface, Handshake localData, Handshake remoteData) 
	throws IOException {
		final boolean argMissing = socket==null || localData==null || remoteData==null;
		if (argMissing) {
			throw new IllegalArgumentException("null arg");
		}
		
		this.socket = socket;
		this.methodInterface = methodInterface;
		this.local = localData;
		this.remote = remoteData;
		
		// A direction is limited by the smaller of sender's and receiver's limit
		this.maxSendSize = Math.min(localData.sendMsgLimit, remoteData.recvMsgLimit);
		this.maxRecvSize = Math.min(localData.recvMsgLimit, remoteData.sendMsgLimit);
		
		this.localMethods = localData.methods;
		this.remoteMethods = remoteData.methods;
		this.remoteType = new Interface(remoteMethods);
		
		// Method id = position in the announced method array
		this.localMethodsMap = new HashMap<MethodTypeDefinition, Integer>();
		this.remoteMethodsMap = new HashMap<MethodTypeDefinition, Integer>();
		int localIndex = 0;
		for (MethodTypeDefinition definition : localMethods) {
			localMethodsMap.put(definition, localIndex++);
		}
		int remoteIndex = 0;
		for (MethodTypeDefinition definition : remoteMethods) {
			remoteMethodsMap.put(definition, remoteIndex++);
		}
		
		this.in = new InputStreamReadable( socket.getInputStream(), Long.MAX_VALUE );
		this.out = new OutputStreamWriteable( socket.getOutputStream() );
		
		thread.setName( "Connection-"+socket.getInetAddress().getHostAddress()+":"+socket.getPort() );
		thread.start();
	}

	/**
	 * @return the interface built from the methods announced by the peer
	 */
	@Override
	public Interface getInterface() {
		return this.remoteType;
	}
	
	/**
	 * Get a handle to a remote method using the bindings suggested by the caller.
	 * 
	 * @param binding method definition together with the object bindings to use
	 * @return method handle that sends requests over this connection
	 * @throws MethodNotSupportedException if the peer has not announced the method
	 *         or serializers cannot be constructed for the bindings
	 */
	@Override
	public Method getMethod(MethodTypeBinding binding)
	throws MethodNotSupportedException {
		// consumer suggests object bindings
		final MethodTypeDefinition definition = binding.getMethodDefinition();
		
		// The map holds no null values, so null means "not announced by the peer"
		final Integer remoteId = remoteMethodsMap.get(definition);
		if (remoteId == null) {
			throw new MethodNotSupportedException(definition.getName());
		}
		
		try {
			return new MethodImpl(remoteId.intValue(), binding);
		} catch (SerializerConstructionException e) {
			throw new MethodNotSupportedException(e);
		}
	}
	
	/**
	 * Get a handle to a remote method using mutable bindings derived from the
	 * request, response and error types of the method definition.
	 * 
	 * @param description method definition
	 * @return method handle that sends requests over this connection
	 * @throws MethodNotSupportedException if the peer has not announced the method
	 *         or serializers cannot be constructed
	 */
	@Override
	public Method getMethod(MethodTypeDefinition description)
	throws MethodNotSupportedException {
		// producer suggests object bindings
		final Integer remoteId = remoteMethodsMap.get(description);
		if (remoteId == null) {
			throw new MethodNotSupportedException(description.getName());
		}
		
		final RecordBinding requestBinding = (RecordBinding) Bindings.getMutableBinding(description.getType().getRequestType());
		final Binding responseBinding = Bindings.getMutableBinding(description.getType().getResponseType());
		final UnionBinding errorBinding = (UnionBinding) Bindings.getMutableBinding(description.getType().getErrorType());
		
		try {
			return new MethodImpl(
					remoteId.intValue(),
					new MethodTypeBinding(description, requestBinding, responseBinding, errorBinding));
		} catch (SerializerConstructionException e) {
			// Generic binding should work
			throw new MethodNotSupportedException(e);
		}
	}
	
	/**
	 * @return the socket this connection was created on
	 */
	public Socket getSocket()
	{
		return this.socket;
	}
	
	/**
	 * Callback interface for observing the life cycle of a connection.
	 * Register with {@link TcpConnection#addConnectionListener(ConnectionListener)}.
	 */
	public interface ConnectionListener {
		/**
		 * There was an error and connection was closed.
		 * Listeners are notified first; the connection is closed right after.
		 * 
		 * @param error the cause
		 */
		void onError(Exception error);
		
		/**
		 * The reader thread found the stream ended (end of stream or a closed
		 * socket). Note that calling {@link TcpConnection#close()} does not by
		 * itself trigger this callback.
		 */
		void onClosed();
	}
	
	// Thread-safe without external locking; iteration works on a snapshot
	CopyOnWriteArrayList<ConnectionListener> listeners = new CopyOnWriteArrayList<ConnectionListener>();
	
	/**
	 * Register a listener for connection errors and closing.
	 * <p>
	 * Deliberately not synchronized: the list is thread-safe on its own, and the
	 * connection monitor is held by writer tasks during blocking socket I/O, so
	 * locking here could stall the caller for no benefit.
	 * 
	 * @param listener listener to add
	 */
	public void addConnectionListener(ConnectionListener listener) {		
		listeners.add( listener );
	}
	
	/**
	 * Unregister a listener. Does nothing if the listener is not registered.
	 * 
	 * @param listener listener to remove
	 */
	public void removeConnectionListener(ConnectionListener listener) {
		listeners.remove( listener );
	}
	
	/**
	 * Handle to one remote method. Invocations are serialized with the bindings
	 * given at construction and written to the socket on {@link #writeExecutor}.
	 */
	class MethodImpl implements Method {
		// Index of the method in the peer's method table
		int methodId;
		MethodTypeBinding methodBinding;
		Serializer responseSerializer;
		Serializer requestSerializer;
		Serializer errorSerializer;
		
		/**
		 * @param methodId index of the method in the peer's method table
		 * @param methodBinding bindings of request, response and error objects
		 * @throws SerializerConstructionException if a serializer cannot be built for a binding
		 */
		MethodImpl(int methodId, MethodTypeBinding methodBinding) throws SerializerConstructionException
		{
			this.methodId = methodId;
			this.methodBinding = methodBinding;
			this.requestSerializer = Bindings.getSerializer( methodBinding.getRequestBinding() );
			this.responseSerializer = Bindings.getSerializer( methodBinding.getResponseBinding() );
			this.errorSerializer = Bindings.getSerializer( methodBinding.getErrorBinding() );
		}		
		
		/**
		 * Send a request to the peer. The call returns immediately; the request is
		 * written by a task on the write executor. Failures to send are reported
		 * through the returned result as an InvokeException.
		 * 
		 * @param request request object, compatible with the request binding
		 * @return result that completes when the peer answers or the request fails
		 */
		@Override
		public AsyncResult invoke(final Object request) {
			// Write, async
			final PendingRequest result = new PendingRequest(this, requestCounter.getAndIncrement());
			// Registered before 'active' is tested so that a close() running
			// concurrently finds the request among the pending ones and cancels it.
			requests.put(result.requestId, result);
						
			if (!active) {
				failRequest(result, new InvokeException(new ConnectionClosedException()));
				return result;
			}
			
			writeExecutor.execute(new Runnable() {
				@Override
				public void run() {			
					// The connection monitor serializes writers on the shared stream
					synchronized(TcpConnection.this) {
						try {
							int size = requestSerializer.getSize(request, outIdentities);
							// Reset before any early exit so the next writer starts clean
							outIdentities.clear();
							if (size>maxSendSize) {
								failRequest(result, new InvokeException(new MessageOverflowException()));
								return;
							}

							RequestHeader reqHeader = new RequestHeader();
							reqHeader.methodId = methodId; 
							reqHeader.requestId = result.requestId;
							MESSAGE_SERIALIZER.serialize(out, outIdentities, reqHeader);
							outIdentities.clear();
							out.writeInt(size);
							requestSerializer.serialize(out, outIdentities, request);
							outIdentities.clear();
							out.flush();
						} catch (IOException e) {
							failRequest(result, new InvokeException(e));
						} catch (RuntimeException e) {
							failRequest(result, new InvokeException(e));
						}
					}
				}});
			return result;
		}

		/**
		 * Complete a request that could not be sent. The request is also dropped
		 * from the pending map; no answer will arrive for it, so it would
		 * otherwise stay there for the lifetime of the connection.
		 */
		private void failRequest(PendingRequest pending, InvokeException cause) {
			requests.remove(pending.requestId);
			pending.setInvokeException(cause);
		}

		@Override
		public MethodTypeBinding getMethodBinding() {
			return methodBinding;
		}
	}
	
	/**
	 * Tell every registered listener that the connection was closed.
	 */
	void setClosed() 
	{
		// Snapshot, so listeners may be added or removed while notifying
		final ConnectionListener[] snapshot = listeners.toArray(new ConnectionListener[0]);
		for (int i = 0; i < snapshot.length; i++) {
			snapshot[i].onClosed();
		}
	}	
	/**
	 * Report an error to every registered listener and then close the connection.
	 * 
	 * @param e the error
	 */
	void setError(Exception e) 
	{
		// Snapshot, so listeners may be added or removed while notifying
		final ConnectionListener[] snapshot = listeners.toArray(new ConnectionListener[0]);
		for (int i = 0; i < snapshot.length; i++) {
			snapshot[i].onError(e);
		}
		close();		
	}
	
	/**
	 * Get the handler of requests made by the peer, i.e. the method interface
	 * that was given to the constructor.
	 * 
	 * @return local method interface
	 */
	public MethodInterface getLocalMethodInterface()
	{
		return this.methodInterface;
	}
	
	/**
	 * Get the method definitions announced by this end in its handshake. The
	 * index of a definition in the array is its method id on the wire.
	 * 
	 * @return local method definitions (the internal array, not a copy)
	 */
	public MethodTypeDefinition[] getLocalMethodDescriptions()
	{
		return this.localMethods;
	}	
	
	/**
	 * Get the method interface for invoking methods of the peer.
	 * 
	 * @return this connection
	 */
	public MethodInterface getRemoteMethodInterface() {
		final MethodInterface remoteSide = this;
		return remoteSide;
	}
	
	/**
	 * Close the connection. New invocations are refused, every pending request
	 * is completed with a ConnectionClosedException, and the reader thread is
	 * interrupted. The socket is not closed by this method.
	 */
	public void close() {
		// Refuse new invocations before cancelling the ones in flight
		active = false;
		
		// Cancel a snapshot of the pending requests, then drop exactly those
		final List<PendingRequest> cancelled = new ArrayList<PendingRequest>(requests.values());
		for (int i = 0; i < cancelled.size(); i++) {
			cancelled.get(i).setInvokeException(new InvokeException(new ConnectionClosedException()));
		}
		requests.values().removeAll(cancelled);
		
		// Ask the reader thread to stop
		thread.interrupt();	
	}
	
	/**
	 * Get the connection whose reader thread is the calling thread.
	 * 
	 * @return Connection or <code>null</code> if the calling thread is not a connection's reader thread
	 */
	public static TcpConnection getCurrentConnection() {
		final Thread current = Thread.currentThread();
		return (current instanceof ConnectionThread)
			? ((ConnectionThread) current).getConnection()
			: null;
	}
	
	/**
	 * Connection Thread deserializes incoming messages from the TCP Stream.
	 * <p>
	 * Each message starts with a {@link Message} header. Requests are dispatched
	 * to the local method interface and answered asynchronously on the write
	 * executor; the other message kinds complete the matching pending request.
	 * <p>
	 * The loop ends when the thread is interrupted, the stream ends or fails, or
	 * a protocol violation is detected. After a stream end or failure the socket
	 * is closed and pending requests are cancelled.
	 */
	class ConnectionThread extends Thread {
		public ConnectionThread() {
			setDaemon(true);
		}
		
		public TcpConnection getConnection() {
			return TcpConnection.this;
		}
		
		@Override
		public void run() {
			while (!Thread.interrupted()) {
				try {
					Message header = (Message) MESSAGE_SERIALIZER.deserialize(in, inIdentities);
					boolean keepReading = true;
					if (header instanceof RequestHeader) {
						keepReading = handleRequest((RequestHeader) header);
					} else if (header instanceof ResponseHeader) {
						keepReading = handleResponse((ResponseHeader) header);
					} else if (header instanceof ExecutionError_) {
						keepReading = handleExecutionError((ExecutionError_) header);
					} else if (header instanceof Exception_) {
						Exception_ msg = (Exception_) header;
						PendingRequest req = takeRequest(msg.requestId);
						if (req != null)
							req.setExecutionError(new Exception(msg.message));
						keepReading = req != null;
					} else if (header instanceof InvalidMethodError) {
						PendingRequest req = takeRequest(((InvalidMethodError) header).requestId);
						if (req != null)
							req.setInvokeException(new InvokeException(new MethodNotSupportedException("?")));
						keepReading = req != null;
					} else if (header instanceof ResponseTooLargeError) {
						PendingRequest req = takeRequest(((ResponseTooLargeError) header).requestId);
						if (req != null)
							req.setInvokeException(new InvokeException(new MessageOverflowException()));
						keepReading = req != null;
					}
					// A protocol violation has already been reported with setError(),
					// which also closed the connection; the socket is left untouched.
					if (!keepReading) return;
					
				} catch (EOFException e) {
					setClosed();
					break;
				} catch (SocketException e) {
					// NOTE(review): JDK socket streams appear to report a locally closed
					// socket as "Socket closed" - confirm. The comparison is
					// case-insensitive and tolerates a null message.
					if ("Socket closed".equalsIgnoreCase(e.getMessage()))
						setClosed();
					else
						setError(e);
					break;
				} catch (IOException e) {
					setError(e);
					break;
				} catch (RuntimeException e) {
					// e.g. a deserialization failure; without this the thread would die
					// silently and pending requests would never be completed
					setError(e);
					break;
				}
			}
			try {
				socket.close();
			} catch (IOException ignored) {
				// best effort, the connection is going down anyway
			}
			// Close pending requests
			close();
		}
		
		/**
		 * Handle an incoming request: read the payload, invoke the local method and
		 * arrange the outcome to be written back.
		 * 
		 * @return false if a protocol violation was reported and reading must stop
		 */
		private boolean handleRequest(final RequestHeader reqHeader) throws IOException {
			final int requestId = reqHeader.requestId;
			int size = in.readInt();
			if (size>maxRecvSize) {
				setError(new MessageOverflowException());
				return false;
			}
			
			int methodId = reqHeader.methodId;
			if (methodId<0||methodId>=localMethods.length) {
				setError(new Exception("ProtocolError"));
				return false;
			}
			MethodTypeDefinition methodDescription = localMethods[methodId];
			// Let back-end determine bindings
			try {
				final Method method = methodInterface.getMethod(methodDescription);
				final MethodTypeBinding methodBinding = method.getMethodBinding();
				// Deserialize payload						
				final Object request = Bindings.getSerializerUnchecked(methodBinding.getRequestBinding()).deserialize(in, inIdentities);
				inIdentities.clear();
					
				// Invoke method
				method.invoke(request).setListener(new InvokeListener() {
					@Override
					public void onCompleted(final Object response) {
						// Write RESP
						writeExecutor.execute(new WriteTask() {
							@Override
							void write() throws IOException {
								ResponseHeader respHeader = new ResponseHeader();
								respHeader.requestId = requestId;
								Serializer serializer = Bindings.getSerializerUnchecked(methodBinding.getResponseBinding());
								writeReply(respHeader, serializer, response, requestId);
							}});
					}
					@Override
					public void onException(final Exception cause) {
						// Write ERRO
						writeExecutor.execute(new WriteTask() {
							@Override
							void write() throws IOException {
								Exception_ msg = new Exception_();
								// The peer needs the id to find the request this belongs to
								msg.requestId = requestId;
								msg.message = cause.getClass().getName()+": "+cause.getMessage(); 
								writeMessage(msg);
							}});									
					}
					@Override
					public void onExecutionError(final Object error) {
						// Write ERRO
						writeExecutor.execute(new WriteTask() {
							@Override
							void write() throws IOException {
								ExecutionError_ errorHeader = new ExecutionError_();
								errorHeader.requestId = requestId;
								Serializer serializer = Bindings.getSerializerUnchecked(methodBinding.getErrorBinding());
								writeReply(errorHeader, serializer, error, requestId);
							}});			
					}});

			} catch (MethodNotSupportedException e) {
				// Drop the payload so that the stream stays aligned on the next header
				in.skipBytes(size);
				// return with an error
				final InvalidMethodError error = new InvalidMethodError();
				error.requestId = requestId;
				writeExecutor.execute(new WriteTask() {
					@Override
					void write() throws IOException {
						writeMessage(error);
					}});							
			} 
			return true;
		}
		
		/**
		 * Handle a response to one of our requests.
		 * 
		 * @return false if a protocol violation was reported and reading must stop
		 */
		private boolean handleResponse(ResponseHeader header) throws IOException {
			PendingRequest req = takeRequest(header.requestId);
			if (req == null) return false;
			// The size prefix is consumed; maxRecvSize is not enforced for replies (TODO)
			in.readInt();
			Object response = req.method.responseSerializer.deserialize(in, inIdentities);
			inIdentities.clear();
			req.setResponse(response);
			return true;
		}
		
		/**
		 * Handle an execution error sent in answer to one of our requests.
		 * 
		 * @return false if a protocol violation was reported and reading must stop
		 */
		private boolean handleExecutionError(ExecutionError_ header) throws IOException {
			PendingRequest req = takeRequest(header.requestId);
			if (req == null) return false;
			// The size prefix is consumed; maxRecvSize is not enforced for replies (TODO)
			in.readInt();
			Object executionError = req.method.errorSerializer.deserialize(in, inIdentities);
			inIdentities.clear();
			req.setExecutionError(executionError);
			return true;
		}
		
		/**
		 * Remove and return the pending request an answer refers to. An unknown
		 * id is reported with setError().
		 * 
		 * @return the request, or null if there is none by that id
		 */
		private PendingRequest takeRequest(int requestId) {
			PendingRequest req = requests.remove(requestId);
			if (req == null)
				setError(new RuntimeException("Request by id "+requestId+" does not exist"));
			return req;
		}
		
		/**
		 * Write a header, a size prefix and a payload, or a ResponseTooLargeError
		 * if the payload exceeds the send limit. Call only from a WriteTask.
		 */
		private void writeReply(Message header, Serializer serializer, Object payload, int requestId) throws IOException {
			int size = serializer.getSize(payload, outIdentities);
			outIdentities.clear();
			if (size > maxSendSize) {
				ResponseTooLargeError tooLarge = new ResponseTooLargeError();
				tooLarge.requestId = requestId;
				// writeMessage flushes, so the error is pushed out like every other reply
				writeMessage(tooLarge);
				return;
			}
			MESSAGE_SERIALIZER.serialize(out, outIdentities, header);
			outIdentities.clear();
			out.writeInt(size);
			serializer.serialize(out, outIdentities, payload);
			outIdentities.clear();
			out.flush();
		}
		
		/**
		 * Write a message that has no payload and flush. Call only from a WriteTask.
		 */
		private void writeMessage(Message message) throws IOException {
			MESSAGE_SERIALIZER.serialize(out, outIdentities, message);
			outIdentities.clear();
			out.flush();
		}
		
		/**
		 * Base of the tasks that write to the peer. The write runs while holding
		 * the connection monitor, and a failure is reported with setError().
		 */
		abstract class WriteTask implements Runnable {
			abstract void write() throws IOException;
			
			@Override
			public void run() {
				synchronized(TcpConnection.this) {
					try {
						write();
					} catch (IOException e) {
						setError(e);
					} catch (RuntimeException e) {
						setError(e);
					}
				}
			}
		}
	}

	// Thread that reads input data. It is created together with the connection
	// object; the constructor names and starts it, and close() interrupts it.
	ConnectionThread thread = new ConnectionThread();
	
	/**
	 * Result of a request that has been invoked and is waiting for the peer's
	 * answer. It remembers the method, whose serializers are used for reading
	 * the answer, and the id that the answer will carry.
	 */
	class PendingRequest extends AsyncResultImpl {

		// method that was invoked
		MethodImpl method;
		
		// id sent in the request header
		int requestId;
		
		public PendingRequest(MethodImpl method, int requestId) {
			super();
			this.requestId = requestId;
			this.method = method; 
		}		
	}
	

	/**
	 * Base type of the message headers exchanged over the connection. The
	 * {@link Union} annotation lists the concrete header types; headers are
	 * written and read with {@link TcpConnection#MESSAGE_SERIALIZER}.
	 */
	@Union({RequestHeader.class, ResponseHeader.class, ExecutionError_.class, Exception_.class, InvalidMethodError.class, ResponseTooLargeError.class})
	public static class Message {}

	/**
	 * Header of a method invocation. On the wire it is followed by the payload
	 * size (int) and the serialized request object.
	 */
	public static class RequestHeader extends Message {
		// id chosen by the requester; answers carry the same id
		public int requestId;
		// index of the method in the method table of the receiving end
		public int methodId;		
		// Request Object
		public RequestHeader() {}
	}

	/**
	 * Header of a successful response. On the wire it is followed by the payload
	 * size (int) and the serialized response object.
	 */
	public static class ResponseHeader extends Message {
		// id of the request being answered
		public int requestId;
		// Response object
		public ResponseHeader() {}
	}

	// Error while invoking a method.
	// On the wire the header is followed by the payload size (int) and the
	// serialized error object.
	public static class ExecutionError_ extends Message {
		// id of the request being answered
		public int requestId;
		// Error object
		public ExecutionError_() {}
	}

	// MethodName does not exist.
	// Sent in answer to a request when the local method interface does not
	// support the requested method. No payload follows the header.
	public static class InvalidMethodError extends Message {
		// id of the request being answered
		public int requestId;
		public InvalidMethodError() {}
	}

	// Exception, not in method but somewhere else.
	// Sent when an invocation ends with an exception; no payload follows the header.
	public static class Exception_ extends Message {
		// id of the request the exception relates to
		public int requestId;
		// class name and message of the exception, as text
		public String message;
		public Exception_() {}
	}

	// Sent instead of a response or an execution error whose serialized size
	// exceeds the sender's send limit. No payload follows the header.
	public static class ResponseTooLargeError extends Message {
		// id of the request being answered
		public int requestId;
		public ResponseTooLargeError() {}
	}
	
}

