"""
General-purpose tools to manipulate figures (scenes) and the pipeline
with the mlab interface.
"""

# Author: Prabhu Ramachandran <prabhu_r@users.sf.net>
# Copyright (c) 2007, Enthought, Inc.
# License: BSD Style.

# Third-party numerical imports.
import numpy

# Enthought library imports.
from enthought.envisage import get_application
from enthought.tvtk.api import tvtk

# MayaVi related imports.
from enthought.mayavi.services import IMAYAVI
from enthought.mayavi.sources.vtk_data_source import VTKDataSource
from enthought.mayavi.app import Mayavi
from enthought.mayavi.core.module_manager import ModuleManager
from enthought.mayavi.sources.array_source import ArraySource
from enthought.mayavi.core.source import Source


######################################################################
# Application and mayavi instances.

application = get_application()
# Reuse the Mayavi service of an already running Envisage application;
# without an application there is no service to look up.
mayavi = (application.get_service(IMAYAVI)
          if application is not None else None)

######################################################################
# Utility functions.
def _make_default_figure():
    """Make sure a valid mayavi instance with an open scene exists.

    A new figure is created with `figure()` when no Envisage application
    is running, when the application has no IMAYAVI service, when it has
    been stopped, or when the engine has no current scene.

    Returns the module-level `mayavi` script instance. `figure()` (through
    `get_mayavi()`) may have replaced that global, and the replaced value
    is what gets returned.
    """
    global application, mayavi
    application = get_application()
    if application is None:
        figure()
        application = get_application()
        if application is None:
            # figure() has already started mayavi and opened a scene (and
            # get_mayavi() stored the script in the `mayavi` global). There
            # is just no application object to query further.
            return mayavi
    mayavi = application.get_service(IMAYAVI)
    if (mayavi is None
            or application.stopped is not None
            or mayavi.engine.current_scene is None):
        figure()
    return mayavi

def _add_data(tvtk_data, name=''):
    """Add `tvtk_data` to the mayavi pipeline of the current figure.

    Parameters
    ----------
    tvtk_data : tvtk.Object or Source
        A raw TVTK data object is wrapped in a new `VTKDataSource`. A mayavi
        `Source` is added as-is.
    name : str, optional
        If non-empty, it is given to the added source as its name.

    Returns the source that was added. A default figure is created first
    if necessary (see `_make_default_figure`).

    Raises TypeError if `tvtk_data` is neither a TVTK object nor a mayavi
    source.
    """
    if isinstance(tvtk_data, tvtk.Object):
        source = VTKDataSource()
        source.data = tvtk_data
    elif isinstance(tvtk_data, Source):
        source = tvtk_data
    else:
        raise TypeError("first argument should be either a TVTK object"
                        " or a mayavi source")

    # Truthiness rather than len(): also tolerates name=None.
    if name:
        source.name = name
    _make_default_figure().add_source(source)
    return source

def _traverse(node):
    """Traverse a tree accessing the nodes children attribute.
    """
    try:
        for leaf in node.children:
            for leaflet in _traverse(leaf):
                yield leaflet
    except AttributeError:
        pass
    yield node

def _find_data(object):
    """Go up the vtk pipeline to find the data sources of `object`.

    `object` may be a ModuleManager, anything with a `module_manager`
    attribute, or a source-like object (one with a `data` or `inputs`
    attribute, or an ArraySource).

    Starting from the corresponding source, nodes that have an `inputs`
    attribute are expanded into their inputs (the work list is a stack,
    so the last input is explored first). Every other node contributes its
    `image_data` if it has that attribute, and its `data` otherwise.

    Returns the list of collected data objects.
    Raises TypeError if no data source can be determined for `object`.
    """
    if isinstance(object, ModuleManager):
        pending = [object.source]
    elif hasattr(object, 'module_manager'):
        pending = [object.module_manager.source]
    elif (hasattr(object, 'data') or isinstance(object, ArraySource)
            or hasattr(object, 'inputs')):
        pending = [object]
    else:
        raise TypeError('Cannot find data source for given object')

    data_sources = []
    # Loop until the stack is empty. An IndexError raised while reading a
    # node's attributes is a real error and is no longer swallowed.
    while pending:
        node = pending.pop()
        if hasattr(node, 'inputs'):
            pending.extend(node.inputs)
        elif hasattr(node, 'image_data'):
            data_sources.append(node.image_data)
        else:
            data_sources.append(node.data)
    return data_sources

def _has_attribute_data(object, attribute):
    """Tell whether any data source of `object` carries `attribute`.

    `attribute` names a data-set attribute such as 'scalars', 'vectors' or
    'tensors'. For every source returned by `_find_data(object)`, the point
    data is checked first, then the cell data. The function returns True
    as soon as one of them has a non-None value, and False otherwise.
    """
    for source in _find_data(object):
        if (getattr(source.point_data, attribute) is not None
                or getattr(source.cell_data, attribute) is not None):
            return True
    return False

def _has_scalar_data(object):
    """Tests if an object has scalar data (point or cell scalars).
    """
    return _has_attribute_data(object, 'scalars')

def _has_vector_data(object):
    """Tests if an object has vector data (point or cell vectors).
    """
    return _has_attribute_data(object, 'vectors')

def _has_tensor_data(object):
    """Tests if an object has tensor data (point or cell tensors).
    """
    return _has_attribute_data(object, 'tensors')

def _find_module_manager(object=None, data_type=None):
    """Return a ModuleManager, optionally one carrying a given kind of data.

    `data_type` may be None, 'scalar', 'vector' or 'tensor'.

    - With `object` given: return its `module_manager` when `data_type` is
      None or the object has that kind of data. Otherwise print a
      diagnostic and return None.
    - With `object` None: traverse the current figure and return the first
      ModuleManager matching `data_type` (any ModuleManager when no check
      applies), or None if there is none.
    """
    data_checks = {'scalar': _has_scalar_data,
                   'vector': _has_vector_data,
                   'tensor': _has_tensor_data}
    has_wanted_data = data_checks.get(data_type)

    if object is None:
        for candidate in _traverse(gcf()):
            if not isinstance(candidate, ModuleManager):
                continue
            if has_wanted_data is None or has_wanted_data(candidate):
                return candidate
        return None

    if not hasattr(object, 'module_manager'):
        print("This object has no color map")
        return None
    if data_type is None or (has_wanted_data is not None
                             and has_wanted_data(object)):
        return object.module_manager
    print("This object has no %s data" % data_type)
    return None

def _orient_colorbar(colorbar, orientation):
    """Lay `colorbar` out horizontally or vertically, then redraw.

    `orientation` is "vertical" or "horizontal". Any other value only
    prints a warning; the figure is redrawn in every case.
    """
    # orientation -> (orientation, width, height, position) for the colorbar
    layouts = {
        "vertical": ("vertical", 0.1, 0.8, (0.01, 0.15)),
        "horizontal": ("horizontal", 0.8, 0.17, (0.1, 0.01)),
    }
    layout = layouts.get(orientation)
    if layout is None:
        print("Unknown orientation")
    else:
        (colorbar.orientation, colorbar.width,
         colorbar.height, colorbar.position) = layout
    draw()

def _typical_distance(data_obj):
    """ Returns a typical distance in a cloud of points.
        This is done by taking the size of the bounding box, and dividing it
        by the cubic root of the number of points.
    """
    x_min, x_max, y_min, y_max, z_min, z_max = data_obj.bounds
    distance = numpy.sqrt(((x_max-x_min)**2 + (y_max-y_min)**2 +
                           (z_max-z_min)**2)/(4*
                           data_obj.number_of_points**(0.33)))
    if distance == 0:
        return 1
    else:
        return 0.4*distance

def _set_extent(module, extents):
    if numpy.all(extents == 0.):
        # That the default setting.
        return
    xmin, xmax, ymin, ymax, zmin, zmax = extents
    xo = 0.5*(xmax + xmin)
    yo = 0.5*(ymax + ymin)
    zo = 0.5*(zmax + zmin)
    extentx = 0.5*(xmax - xmin)
    extenty = 0.5*(ymax - ymin)
    extentz = 0.5*(zmax - zmin)
    # Now the actual bounds.
    xmin, xmax, ymin, ymax, zmin, zmax = module.actor.actor.bounds
    # Scale the object
    boundsx = 0.5*(xmax - xmin)
    boundsy = 0.5*(ymax - ymin)
    boundsz = 0.5*(zmax - zmin)
    xs, ys, zs = module.actor.actor.scale
    module.actor.actor.scale = (xs*extentx/boundsx,
                                        ys*extenty/boundsy,
                                        zs*extentz/boundsz)
    ## Remeasure the bounds
    xmin, xmax, ymin, ymax, zmin, zmax = module.actor.actor.bounds
    xcenter = 0.5*(xmax + xmin)
    ycenter = 0.5*(ymax + ymin)
    zcenter = 0.5*(zmax + zmin)         
    # Center the object                 
    module.actor.actor.origin = (0.,  0.,  0.)
    xpos, ypos, zpos = module.actor.actor.position
    module.actor.actor.position = (xpos + xo -xcenter, ypos + yo - ycenter,
                                            zpos + zo -zcenter)

def get_mayavi():
    """ Return a running mayavi script engine, starting Mayavi if needed.

        The IMAYAVI service of the current Envisage application is reused
        when it exists and the application has not been stopped. Otherwise
        a new `Mayavi` application is started and its `script` is returned.
        The module-level `application` and `mayavi` globals are updated.
    """
    global application, mayavi
    application = get_application()
    if application is not None:
        mayavi = application.get_service(IMAYAVI)
        usable = mayavi is not None and application.stopped is None
        if usable:
            return mayavi
    # No usable running instance: launch a fresh Mayavi application.
    new_app = Mayavi()
    new_app.main()
    mayavi = new_app.script
    return mayavi

def figure():
    """Create a new figure (scene) and return it.

    Starts mayavi first if none is running, which is convenient when
    working from IPython. It then opens a new scene, applies the default
    view (azimuth 40, elevation 50), and returns the engine's current
    scene. It does not return the script instance.
    """
    script = get_mayavi()
    script.new_scene()
    view(40, 50)
    return script.engine.current_scene

def gcf():
    """Return the current figure (scene), creating one if none is open.

    A freshly created scene gets the default view (azimuth 40,
    elevation 50).
    """
    script = get_mayavi()
    current = script.engine.current_scene
    if current is not None:
        return current
    script.new_scene()
    current = script.engine.current_scene
    view(40, 50)
    return current

def clf():
    """Clear the current figure.

    Removes every child of the current scene while the scene's
    `disable_render` flag is set (presumably to avoid a redraw per removed
    child). The flag is now always reset, even if the removal fails.
    Previously an error here left rendering disabled. AttributeErrors are
    still ignored, keeping this a best-effort operation.
    """
    try:
        scene = gcf()
        view_scene = scene.scene
        view_scene.disable_render = True
        try:
            scene.children[:] = []
        finally:
            view_scene.disable_render = False
    except AttributeError:
        pass

def draw():
    """Render the current figure (as returned by `gcf`)."""
    scene = gcf()
    scene.render()

def savefig(filename, size=None, **kwargs):
    """ Save the current scene to `filename`.

        The output format is deduced from the extension of `filename`.
        Possible formats are png, jpg, bmp, tiff, ps, eps, pdf,
        rib (renderman), oogl (geomview), iv (OpenInventor), vrml and
        obj (wavefront).

        `size`, if given, is a 2-tuple: the window is resized to it so that
        the output image has that size. Note that the resized window may
        be obscured by other widgets, and the camera zoom is not reset, so
        the image may not reflect what is seen on screen.

        Any extra keyword arguments are passed along to the save method
        of the respective image format.
    """
    current_scene = gcf().scene
    current_scene.save(filename, size=size, **kwargs)

def _xyz2rthetaphi(vec):
    """ Returns an r, theta, phi vector for an yxz one (! angles in
    degrees, x and y swapped compared to cylindrical coords)"""
    pi = numpy.pi
    r = numpy.sqrt(numpy.square(vec).sum())
    vec = vec / r
    theta = numpy.arccos(vec[2])*180/pi
    xy = vec[0:2] / numpy.sqrt(numpy.square(vec[0:2]).sum())
    phi = numpy.arccos(xy[1])*180/pi
    if numpy.isnan(phi):
        phi=0
    return numpy.array([r, theta, phi])

def _rthetaphi2xyz(vec):
    """ Returns an xyz vector from an r, theta, phi one (! angles in
    degrees, x and y swapped compared to cylindrical coords)"""
    r, theta, phi = vec
    pi = numpy.pi
    cos = numpy.cos
    sin = numpy.sin
    return r*numpy.array([ sin(theta*pi/180.)*sin(phi*pi/180.),
                        sin(theta*pi/180.)*cos(phi*pi/180.),
                        cos(theta*pi/180.)])


def _constraint_thetaphi(theta, phi):
    """ Constraint theta, phi to [0, 180] x [-180, 180] """
    n_theta = numpy.floor(theta/180.)
    theta = abs((theta +180) % 360 - 180)
    phi = ((n_theta*180 + phi + 180) % 360) - 180
    return theta, phi

def guess_roll1(phi, theta, phi_orig):
    """ Earlier heuristic for the camera roll angle, using a Gaussian
        weight around the equator.

        Expects theta, phi in [0, 180] x [-180, 180] and the unconstrained
        azimuth `phi_orig`. `view` does not use it (it calls `guess_roll`).
    """
    offset = theta - 90
    hemisphere = numpy.sign(offset)
    equator_factor = numpy.exp(-offset**2/400.)*90
    if hemisphere == 0:
        hemisphere = -1

    if abs(offset) < 3:
        # Near the equator.
        if phi in (0, 180, -180):
            return 90 + hemisphere*90
        return numpy.sign(phi)*hemisphere*90
    if abs(phi % 180) < 3:
        # FIXME: This should probably be turned into an "azimuth_factor".
        return 0

    azimuth_sign = numpy.sign(phi_orig)
    if azimuth_sign == 0:
        azimuth_sign = 1
    blended = phi_orig/(1 + equator_factor) + equator_factor
    return 90*(1 + hemisphere) - azimuth_sign*hemisphere*blended

def guess_roll(phi, theta, phi_orig):
    """ Magic code to get the roll angle right. That was real hard to
        figure out !

        NOTE(review): dead code. Later definitions of ``guess_roll`` in
        this module rebind the name, and `view` calls the last one
        (``guess_roll(phi, theta)``), so this version is never reachable
        under this name. It is kept only as a record of an earlier attempt.
    """
    # theta, phi in [0, 180] x [-180, 180]
    sign = numpy.sign(theta - 90)
    # NOTE(review): numpy.cos expects radians but theta is in degrees here;
    # looks like a unit bug should this version ever be revived.
    equator_factor = (numpy.cos(theta-90))**2
    if sign == 0:
        sign = -1
    if abs(theta - 90) < 3:
        if phi == 0:
            return 90 + sign*90
        elif phi == 180 or phi == -180:
            return 90 + sign*90
        else:
            return numpy.sign(phi)*sign*90
    elif abs(phi %180) < 3:
        # FIXME: This should probably be transformed in an
        # "azimuth_factor"
        return 0        
    signp = numpy.sign(phi_orig)
    if signp == 0:
        signp=1
    # NOTE(review): numpy.sin(phi_orig) is also fed degrees (see above).
    return 90*(1+sign)- signp*sign*( 90*numpy.sin(phi_orig)*(1 - equator_factor) 
                                                    + 90*equator_factor)

def guess_roll(phi, theta, phi_orig):
    """ Magic code to get the roll angle right. That was real hard to
        figure out !

        NOTE(review): dead, experimental code. The following definition of
        ``guess_roll(phi, theta)`` rebinds the name, and that is the one
        `view` calls. The print statements below are leftover debugging
        output, and `cos`, `sin` and `p` are never used.
    """
    # theta, phi in [0, 180] x [-180, 180]
    print "theta", theta
    print "phi", phi
    print "phi_orig", phi_orig
    if theta==0:
        return phi
    elif theta==90:
        return -numpy.sign(phi)*90
    pi = numpy.pi
    # Degree-based trig helpers (unused below).
    cos = lambda t: numpy.cos(t/180.*pi)
    sin = lambda t: numpy.sin(t/180.*pi)
    theta = float(theta)
    t = theta - 90
    p = phi - 90
    # From here on the unconstrained azimuth is used.
    phi = phi_orig
    print "(t/theta)", t/theta
    roll = ((90/theta)**2*90
                *numpy.tanh(phi*(theta/(90*t))**2))
    print 'roll', roll
    return roll

def guess_roll(phi, theta):
    """ Tries to get the roll angle right to make the picture look good.

        `phi` and `theta` are the camera's azimuth and polar angle in
        degrees, expected in [-180, 180] x [0, 180] (as produced by
        `_constraint_thetaphi`). Returns the roll angle in degrees.
        See also: roll
    """
    # This is really magic fudge! Before modifying this, spend a long
    # time understanding the problem. There must be a rigorous way of
    # doing this, but I couldn't find any, and the net didn't help.
    #
    # The functions here were found by considering special lines
    # (along the equator, on the principal meridian, near the pole),
    # finding the right value for "roll" on these lines, and finding an
    # extrapolation on the complete globe. Of course there is a singular
    # point. This is unavoidable (hairy ball theorem).
    if theta == 0:
        return phi
    elif theta == 90:
        if phi in (-180, 180):
            return 0
        return -numpy.sign(phi)*90

    offset = 0.
    if theta > 90:
        # Lower half of the sphere: mirror theta onto the upper half and
        # add a half turn to the result.
        theta = 180. - theta
        offset = 180.
    # Fold phi into [-90, 90].
    if phi > 90:
        phi = 180 - phi
    elif phi < -90:
        phi = -180 - phi

    def _interpolated_roll(p, t):
        # sign(p)*90*(|p|/90)**sqrt((90 - t)/90): equals p at t = 0 and
        # sign(p)*90 at t = 90.
        return numpy.sign(p)*90*pow(abs(p)/90., numpy.sqrt((90 - t)/90.))

    return offset - _interpolated_roll(phi, theta)

def view(azimuth=None, elevation=None, distance=None, focalpoint=None):
    """ Sets the view point for the camera.

    azimuth: angle in the horizontal plane, in degrees.
    elevation: angle in degrees. The camera's polar angle (measured from
        the +z axis) is set to -elevation and then normalised with
        `_constraint_thetaphi`.
    distance: distance between the camera and its focal point.
    focalpoint: the point the camera looks at.

    Parameters left as None are unchanged. If at least one is given, the
    camera is moved, its roll is guessed with `guess_roll`, its clipping
    range is reset, and the figure is redrawn.

    Returns (azimuth, elevation, distance) for the resulting view, with the
    elevation in the same sign convention as the `elevation` argument.
    see also: roll
    """
    # XXX: It might be more sensible to have elevation = 90+theta
    # Currently theta = - elevation
    camera = gcf().scene._renderer.active_camera
    if focalpoint is not None:
        camera.focal_point = focalpoint
    r, theta, phi = _xyz2rthetaphi(camera.position - camera.focal_point)
    if azimuth is not None:
        phi = azimuth
    if elevation is not None:
        theta = -elevation
    theta, phi = _constraint_thetaphi(theta, phi)
    if distance is not None:
        r = distance

    requested = (azimuth, elevation, distance, focalpoint)
    if any(arg is not None for arg in requested):
        offset = _rthetaphi2xyz([r, theta, phi])
        camera.position = camera.focal_point + numpy.array(offset)
        camera.orthogonalize_view_up()
        camera.set_roll(guess_roll(phi, theta))
        # FIXME: vtk knows how to calculate that. Need to find out.
        camera.clipping_range = r*numpy.array([0.2, 1.5])
        draw()
    return phi, -theta, r

def roll(roll=None):
    """ Sets or returns the absolute roll angle of the camera.

    If `roll` is given, the active camera's roll is set to it and the
    figure is redrawn. In both cases the camera's current roll is
    returned.
    """
    camera = gcf().scene._renderer.active_camera
    if roll is None:
        return camera.get_roll()
    camera.set_roll(roll)
    draw()
    return camera.get_roll()

