/*
* JBoss, Home of Professional Open Source
* Copyright 2005, JBoss Inc., and individual contributors as indicated
* by the @authors tag. See the copyright.txt in the distribution for a
* full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/

package org.jboss.remoting.transport.bisocket;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.net.Socket;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;

import org.jboss.logging.Logger;
import org.jboss.remoting.Client;
import org.jboss.remoting.ConnectionFailedException;
import org.jboss.remoting.InvocationRequest;
import org.jboss.remoting.InvokerLocator;
import org.jboss.remoting.invocation.InternalInvocation;
import org.jboss.remoting.marshal.Marshaller;
import org.jboss.remoting.marshal.UnMarshaller;
import org.jboss.remoting.transport.BidirectionalClientInvoker;
import org.jboss.remoting.transport.socket.SocketClientInvoker;
import org.jboss.remoting.transport.socket.SocketWrapper;

/**
 * The bisocket transport, an extension of the socket transport, is designed to allow
 * a callback server to function behind a firewall.  All connections are created by
 * a Socket constructor or factory on the client side connecting to a ServerSocket on
 * the server side.  When a callback client invoker on the server side needs to
 * open a connection to the callback server, it requests a connection by sending a
 * request message over a control connection to the client side.
 *
 * Because all connections are created in one direction, the bisocket transport is
 * asymmetric, in the sense that client invokers and server invokers behave differently
 * on the client side and on the server side.
 *
 * @author <a href="mailto:ron.sigal@jboss.com">Ron Sigal</a>
 */
public class BisocketClientInvoker
extends SocketClientInvoker
implements BidirectionalClientInvoker
{
   private static final Logger log = Logger.getLogger(BisocketClientInvoker.class);

   /** listenerId -> client-side invoker, registered by transport() on ADDLISTENER. */
   private static final Map listenerIdToClientInvokerMap = Collections.synchronizedMap(new HashMap());

   /** listenerId -> server-side callback invoker, registered by the constructor. */
   private static final Map listenerIdToCallbackClientInvokerMap = Collections.synchronizedMap(new HashMap());

   /**
    * listenerId -> Set of Sockets handed over by transferSocket() and not yet claimed.
    * The map is guarded by its own monitor; each Set is guarded by its own monitor.
    */
   private static final Map listenerIdToSocketsMap = new HashMap();

   /** Daemon timer shared by all PingTimerTasks; accessed only in schedulePingTimerTask(). */
   private static Timer timer;

   /** From the config (callback invokers) or from the ADDLISTENER request payload (client invokers). */
   protected String listenerId;

   private int pingFrequency = Bisocket.PING_FREQUENCY_DEFAULT;
   private int maxRetries = Bisocket.MAX_RETRIES_DEFAULT;
   private InvokerLocator secondaryLocator;
   private Socket controlSocket;
   private OutputStream controlOutputStream;

   /** Serializes writes to controlOutputStream (pings and socket-creation requests). */
   private final Object controlLock = new Object();

   private PingTimerTask pingTimerTask;

   /** True when the config carried a listenerId, i.e. this is a callback invoker on the server side. */
   protected boolean isCallbackInvoker;


   /**
    * Looks up the client-side invoker registered for a listener.
    *
    * @param listenerId id of the callback listener
    * @return the registered invoker, or null if none is registered
    */
   static BisocketClientInvoker getBisocketClientInvoker(String listenerId)
   {
      return (BisocketClientInvoker) listenerIdToClientInvokerMap.get(listenerId);
   }


   /**
    * Looks up the server-side callback invoker registered for a listener.
    *
    * @param listenerId id of the callback listener
    * @return the registered callback invoker, or null if none is registered
    */
   static BisocketClientInvoker getBisocketCallbackClientInvoker(String listenerId)
   {
      return (BisocketClientInvoker) listenerIdToCallbackClientInvokerMap.get(listenerId);
   }


   /**
    * Unregisters the client-side invoker for a listener.
    *
    * @param listenerId id of the callback listener
    */
   static void removeBisocketClientInvoker(String listenerId)
   {
      listenerIdToClientInvokerMap.remove(listenerId);
   }


   /**
    * Hands a newly accepted socket to whichever thread is waiting for a socket for
    * this listener (the constructor for the control socket, or createSocket()).
    *
    * @param listenerId id of the callback listener
    * @param socket     socket to make available
    */
   static void transferSocket(String listenerId, Socket socket)
   {
      Set sockets = getOrCreateSocketSet(listenerId);

      synchronized (sockets)
      {
         sockets.add(socket);
         // Waiters loop on isEmpty(), so waking all of them is safe and avoids a lost
         // wakeup when the single notified thread is not the one that can use the socket.
         sockets.notifyAll();
      }
   }


   /**
    * Returns the pending-socket Set for a listener, creating and registering it if absent.
    */
   private static Set getOrCreateSocketSet(String listenerId)
   {
      synchronized (listenerIdToSocketsMap)
      {
         Set sockets = (Set) listenerIdToSocketsMap.get(listenerId);
         if (sockets == null)
         {
            sockets = new HashSet();
            listenerIdToSocketsMap.put(listenerId, sockets);
         }
         return sockets;
      }
   }


   /**
    * Waits until the given Set contains a socket, then removes and returns one.
    *
    * @param sockets            pending-socket Set (used as the wait monitor)
    * @param timeoutMs          maximum wait in milliseconds; values <= 0 mean wait indefinitely
    * @param interruptedMessage message for the InterruptedIOException thrown on interrupt
    * @param timeoutMessage     message for the IOException thrown on timeout
    * @return a socket taken from the Set
    * @throws InterruptedIOException if the thread is interrupted while waiting
    * @throws IOException            if no socket arrives within timeoutMs
    */
   private static Socket takeSocket(Set sockets, int timeoutMs,
                                    String interruptedMessage, String timeoutMessage)
   throws IOException
   {
      synchronized (sockets)
      {
         long deadline = System.currentTimeMillis() + timeoutMs;

         // Loop, rather than wait once, to tolerate spurious wakeups and
         // notifications whose socket was taken by another waiter.
         while (sockets.isEmpty())
         {
            long remaining = 0; // 0 => Object.wait() blocks indefinitely
            if (timeoutMs > 0)
            {
               remaining = deadline - System.currentTimeMillis();
               if (remaining <= 0)
                  throw new IOException(timeoutMessage);
            }

            try
            {
               sockets.wait(remaining);
            }
            catch (InterruptedException e)
            {
               // Preserve the interrupt for callers further up the stack.
               Thread.currentThread().interrupt();
               log.warn("unexpected interrupt");
               InterruptedIOException iioe = new InterruptedIOException(interruptedMessage);
               iioe.initCause(e);
               throw iioe;
            }
         }

         Iterator it = sockets.iterator();
         Socket socket = (Socket) it.next();
         it.remove();
         return socket;
      }
   }


   /**
    * Schedules a ping task on the shared daemon timer, creating the timer if needed.
    * Synchronized on the class so concurrent invokers cannot race on the lazy
    * initialization of the static timer.
    */
   private static synchronized void schedulePingTimerTask(TimerTask task, int frequency)
   {
      if (timer == null)
         timer = new Timer(true);

      try
      {
         timer.schedule(task, frequency, frequency);
      }
      catch (IllegalStateException e)
      {
         // java.util.Timer is permanently dead once its thread has been terminated
         // (e.g. by an exception escaping a task); replace it rather than failing forever.
         log.debug("shared ping timer is no longer usable; creating a new one");
         timer = new Timer(true);
         timer.schedule(task, frequency, frequency);
      }
   }


   public BisocketClientInvoker(InvokerLocator locator) throws IOException
   {
      this(locator, null);
   }


   /**
    * Creates an invoker.  If config contains Client.LISTENER_ID_KEY, the invoker
    * becomes a callback invoker: it registers itself, waits (up to the inherited
    * timeout) for a control socket to be handed over via transferSocket(), and
    * starts pinging over that control connection.
    *
    * @param locator locator of the target server
    * @param config  configuration map, may be null
    * @throws IOException if the control socket cannot be obtained
    */
   public BisocketClientInvoker(InvokerLocator locator, Map config) throws IOException
   {
      super(locator, config);

      if (config != null)
      {
         listenerId = (String) config.get(Client.LISTENER_ID_KEY);
         if (listenerId != null)
         {
            isCallbackInvoker = true;
            listenerIdToCallbackClientInvokerMap.put(listenerId, this);
            log.debug("registered " + listenerId + " -> " + this);
         }

         // look for pingFrequency param
         Object val = config.get(Bisocket.PING_FREQUENCY);
         if (val != null)
         {
            try
            {
               // toString() accepts String values as before, and also numeric values.
               pingFrequency = Integer.parseInt(val.toString());
               log.debug("Setting ping frequency to: " + pingFrequency);
            }
            catch (NumberFormatException e)
            {
               log.warn("Could not convert " + Bisocket.PING_FREQUENCY +
                     " value of " + val + " to an int value.");
            }
         }

         // NOTE(review): unlike PING_FREQUENCY, MAX_RETRIES is read from the inherited
         // "configuration" field rather than from config.  Kept as-is because the two maps
         // may differ (e.g. if the superclass merges locator parameters) - confirm intent.
         val = configuration.get(Bisocket.MAX_RETRIES);
         if (val != null)
         {
            try
            {
               maxRetries = Integer.parseInt(val.toString());
               log.debug("Setting retry limit: " + maxRetries);
            }
            catch (NumberFormatException e)
            {
               log.warn("Could not convert " + Bisocket.MAX_RETRIES +
                     " value of " + val + " to an int value.");
            }
         }
      }

      if (isCallbackInvoker)
      {
         Set sockets = getOrCreateSocketSet(listenerId);
         controlSocket = takeSocket(sockets, timeout,
                                    "Attempt to create control socket interrupted",
                                    "Timed out trying to create control socket");
         controlOutputStream = controlSocket.getOutputStream();
         log.debug("got control socket: " + controlSocket);
         startPingTimerTask();
      }
   }


   public int getPingFrequency()
   {
      return pingFrequency;
   }


   /**
    * Sets the ping period in milliseconds.  Takes effect the next time a ping task
    * is scheduled (construction or replaceControlSocket()).
    */
   public void setPingFrequency(int pingFrequency)
   {
      this.pingFrequency = pingFrequency;
   }


   protected void handleConnect() throws ConnectionFailedException
   {
      // Callback client on server side.
      if (isCallbackInvoker)
      {
         // Bisocket callback client invoker doesn't share socket pools because of the danger
         // that two distinct callback servers could have the same "artificial" port.
         pool = new LinkedList();
         return;
      }

      // Client on client side.
      super.handleConnect();

      try
      {
         secondaryLocator = getSecondaryLocator();
      }
      catch (Throwable e)
      {
         log.error("Unable to retrieve address/port of secondary server socket", e);
         throw new ConnectionFailedException(e.getMessage());
      }
   }


   /**
    * Unregisters this invoker and releases its resources.  A callback invoker closes
    * its pooled sockets and its control socket; a client invoker delegates to the
    * superclass.  In both cases the listener's pending-socket Set is discarded and
    * the ping task is stopped.
    */
   protected void handleDisconnect()
   {
      if (listenerId == null)
      {
         super.handleDisconnect();
         return;
      }

      if (isCallbackInvoker)
      {
         listenerIdToCallbackClientInvokerMap.remove(listenerId);
         closePooledSockets();
      }
      else
      {
         listenerIdToClientInvokerMap.remove(listenerId);
         super.handleDisconnect();
      }

      // listenerIdToSocketsMap is a plain HashMap; every other access holds its monitor.
      synchronized (listenerIdToSocketsMap)
      {
         listenerIdToSocketsMap.remove(listenerId);
      }

      if (pingTimerTask != null)
         pingTimerTask.shutDown();

      if (isCallbackInvoker)
         closeControlSocket();
   }


   /** Best-effort close of every socket in this callback invoker's private pool. */
   private void closePooledSockets()
   {
      if (pool == null)
         return;

      synchronized (pool)
      {
         for (Iterator it = pool.iterator(); it.hasNext();)
         {
            SocketWrapper socketWrapper = (SocketWrapper) it.next();
            try
            {
               socketWrapper.close();
            }
            catch (Exception ignored)
            {
               // Best effort: a socket that fails to close is discarded anyway.
            }
         }
         // Closed wrappers must not be handed out again.
         pool.clear();
      }
   }


   /** Best-effort close of the control socket, so it is not leaked after disconnect. */
   private void closeControlSocket()
   {
      synchronized (controlLock)
      {
         if (controlSocket == null)
            return;

         try
         {
            controlSocket.close();
         }
         catch (IOException e)
         {
            log.debug("unable to close control socket: " + controlSocket);
         }
      }
   }


   /**
    * On a client-side ADDLISTENER request (with a non-null locator), registers this
    * invoker under the request's listenerId and asks the local callback server invoker
    * to create the control connection to the secondary locator; then delegates to
    * the superclass.
    */
   protected Object transport(String sessionId, Object invocation, Map metadata,
                              Marshaller marshaller, UnMarshaller unmarshaller)
   throws IOException, ConnectionFailedException, ClassNotFoundException
   {
      if (invocation instanceof InvocationRequest)
      {
         InvocationRequest ir = (InvocationRequest) invocation;
         Object o = ir.getParameter();
         if (o instanceof InternalInvocation)
         {
            InternalInvocation ii = (InternalInvocation) o;
            if (InternalInvocation.ADDLISTENER.equals(ii.getMethodName())
                && ir.getLocator() != null) // getLocator() == null for pull callbacks
            {
               Map requestPayload = ir.getRequestPayload();
               listenerId = (String) requestPayload.get(Client.LISTENER_ID_KEY);
               listenerIdToClientInvokerMap.put(listenerId, this);
               BisocketServerInvoker callbackServerInvoker;
               callbackServerInvoker = BisocketServerInvoker.getBisocketServerInvoker(listenerId);
               callbackServerInvoker.createControlConnection(listenerId, secondaryLocator);
            }

            // Rather than handle the REMOVELISTENER case symmetrically, it is
            // handled when a REMOVECLIENTLISTENER message is received by
            // BisocketServerInvoker.handleInternalInvocation().  The reason is that
            // if the Client executes removeListener() with disconnectTimeout == 0,
            // no REMOVELISTENER message will be sent.
         }
      }

      return super.transport(sessionId, invocation, metadata, marshaller, unmarshaller);
   }


   /**
    * Client invokers create sockets normally.  Callback invokers instead write a
    * CREATE_ORDINARY_SOCKET request on the control connection and wait for the
    * resulting socket to be handed over via transferSocket().
    *
    * @param timeout wait limit in ms; negative means use getTimeout() (negative again => no limit)
    * @throws IOException if the invoker has been disconnected, the request cannot be
    *                     written, or no socket arrives in time
    */
   protected Socket createSocket(String address, int port, int timeout) throws IOException
   {
      if (!isCallbackInvoker)
         return super.createSocket(address, port, timeout);

      if (timeout < 0)
      {
         timeout = getTimeout();
         if (timeout < 0)
            timeout = 0;
      }

      Set sockets;
      synchronized (listenerIdToSocketsMap)
      {
         sockets = (Set) listenerIdToSocketsMap.get(listenerId);
      }

      // The Set is removed by handleDisconnect(); fail cleanly instead of with an NPE,
      // and before sending a request whose socket nobody would claim.
      if (sockets == null)
         throw new IOException("No pending-socket set for listener " + listenerId
                               + ": invoker has been disconnected");

      synchronized (controlLock)
      {
         controlOutputStream.write(Bisocket.CREATE_ORDINARY_SOCKET);
      }

      Socket socket = takeSocket(sockets, timeout,
                                 "Attempt to create callback socket interrupted",
                                 "Timed out trying to create socket");
      log.debug("found socket: " + socket);
      return socket;
   }


   /**
    * Switches to a new control socket and restarts pinging over it.
    *
    * @param socket the new control socket
    * @throws IOException if the socket's output stream cannot be obtained
    */
   void replaceControlSocket(Socket socket) throws IOException
   {
      synchronized (controlLock)
      {
         controlSocket = socket;
         controlOutputStream = controlSocket.getOutputStream();
         log.debug("replaced control socket");
      }

      if (pingTimerTask != null)
         pingTimerTask.cancel();

      startPingTimerTask();
   }


   /** Creates a ping task bound to the current control stream and schedules it. */
   private void startPingTimerTask()
   {
      pingTimerTask = new PingTimerTask(this);
      schedulePingTimerTask(pingTimerTask, pingFrequency);
   }


   /**
    * Asks the server for its secondary locator, making up to maxRetries attempts
    * (always at least one).
    *
    * @return the secondary locator, also stored in secondaryLocator
    * @throws Throwable the exception from the last failed attempt
    */
   InvokerLocator getSecondaryLocator() throws Throwable
   {
      InternalInvocation ii = new InternalInvocation(Bisocket.GET_SECONDARY_INVOKER_LOCATOR, null);
      InvocationRequest r = new InvocationRequest(null, null, ii, null, null, null);
      log.debug("getting secondary locator");

      // With maxRetries <= 0 the loop would never run and "throw savedException"
      // would throw a NullPointerException; always try at least once.
      int attempts = Math.max(1, maxRetries);
      Exception savedException = null;

      for (int i = 0; i < attempts; i++)
      {
         try
         {
            Object o = invoke(r);
            log.debug("secondary locator: " + o);
            secondaryLocator = (InvokerLocator) o;
            return secondaryLocator;
         }
         catch (Exception e)
         {
            savedException = e;
            if (i + 1 < attempts)
               log.info("unable to get secondary locator: trying again");
         }
      }

      throw savedException;
   }


   /**
    * Builds the callback server locator from the CALLBACK_SERVER_* metadata entries.
    *
    * @param metadata map containing protocol, host and (optional) port as Strings
    * @return locator with path "callback"; port is -1 when no port is given
    * @throws RuntimeException if the port entry is not a number
    */
   public InvokerLocator getCallbackLocator(Map metadata)
   {
      String transport = (String) metadata.get(Client.CALLBACK_SERVER_PROTOCOL);
      String host = (String) metadata.get(Client.CALLBACK_SERVER_HOST);
      String sPort = (String) metadata.get(Client.CALLBACK_SERVER_PORT);
      int port = -1;
      if (sPort != null)
      {
         try
         {
            port = Integer.parseInt(sPort);
         }
         catch (NumberFormatException e)
         {
            throw new RuntimeException("Can not set internal callback server port as configuration value ("
                                       + sPort + ") is not a number.", e);
         }
      }

      return new InvokerLocator(transport, host, port, "callback", metadata);
   }


   /**
    * Periodically writes Bisocket.PING on the control stream captured at construction.
    * Stops itself on the first write failure.
    */
   static class PingTimerTask extends TimerTask
   {
      private final Object controlLock;
      private OutputStream controlOutputStream; // guarded by controlLock; null once shut down

      PingTimerTask(BisocketClientInvoker invoker)
      {
         controlLock = invoker.controlLock;
         controlOutputStream = invoker.controlOutputStream;
      }

      public void shutDown()
      {
         synchronized (controlLock)
         {
            controlOutputStream = null;
         }
         cancel();
      }

      public void run()
      {
         synchronized (controlLock)
         {
            // shutDown() may have completed while this execution waited for the lock.
            // Writing to the null stream would throw an NPE, and an exception escaping
            // a TimerTask terminates the shared Timer thread for every invoker.
            if (controlOutputStream == null)
               return;

            try
            {
               controlOutputStream.write(Bisocket.PING);
            }
            catch (IOException e)
            {
               log.warn("Unable to send ping: shutting down PingTimerTask");
               shutDown();
            }
         }
      }
   }
}