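"""
Module defines the two basic distance measures *Trace Distance* and
*Quantum Infidelity* for quantum states given either as pure states or as
density matrices. Moreover, the maximum distance between the single-site or
two-site density matrices of two simulation outputs can be determined.
"""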
import numpy as np
#import numpy.linalg as la
# Public API exported by ``from <module> import *``; the remaining
# functions stay importable by name.
__all__ = [
    'distance',
    'infidelity',
    'tracedist',
    'single_site_rho_max_distance',
    'two_site_rho_max_distance',
]
def distance(StateA, StateB, dist='Infidelity'):
    """
    Calculate the distance between two quantum states. States can be
    either pure states or density matrices. Must be defined on a
    Hilbert space of equal size.

    **Arguments**

    StateA : 1d or 2d numpy array
        Represents the first quantum state as pure state or density
        matrix.
    StateB : 1d or 2d numpy array
        Represents the second quantum state as pure state or density
        matrix.
    dist : str, optional
        Define the measure used, either ``Infidelity`` or ``Trace``
        distance (``TraceDistance`` is accepted as an alias of ``Trace``).
        Default to ``Infidelity``.

    **Returns**

    The value returned by :func:`infidelity` or :func:`tracedist`.

    **Raises**

    ValueError
        If ``dist`` is not a known identifier. ``ValueError`` derives from
        ``Exception``, so existing ``except Exception`` handlers still work.
    """
    if dist == 'Infidelity':
        return infidelity(StateA, StateB)
    # 'TraceDistance' is the spelling used by the *_rho_max_distance
    # helpers, so accept it here as well for consistency.
    if dist in ('Trace', 'TraceDistance'):
        return tracedist(StateA, StateB)
    raise ValueError(
        "Unknown string identifier for distance: {!r}.".format(dist))
   
def single_site_rho_max_distance(Out1, Out2, ll, dist='Infidelity'):
    """
    Return the largest distance between corresponding single-site density
    matrices of two simulation outputs.

    **Arguments**

    Out1 : dictionary
        Output of the first simulation; ``Out1['rho'][str(site)]`` must hold
        the density matrix of site ``site`` for ``site = 1, ..., ll``.
    Out2 : dictionary
        Output of the second simulation, with the same layout as ``Out1``.
    ll : int
        System size, i.e. the number of sites compared.
    dist : str, optional
        ``'Infidelity'`` (default) selects the infidelity; any other value,
        e.g. ``'TraceDistance'``, selects the trace distance.

    **Returns**

    The maximum of the real parts of the site-wise distances, never below
    ``0.0`` (``0.0`` is returned when no site is compared).
    """
    measure = infidelity_rho_rho if dist == 'Infidelity' else tracedist_rho_rho

    # Seed with 0.0 so that an empty comparison yields 0.0, exactly like a
    # running maximum that starts at zero.
    candidates = [0.0]
    candidates.extend(
        np.real(measure(Out1['rho'][str(site)], Out2['rho'][str(site)]))
        for site in range(1, ll + 1)
    )
    return max(candidates)
def two_site_rho_max_distance(Out1, Out2, ll, dist='Infidelity'):
    """
    Return the largest distance between corresponding two-site density
    matrices of two simulation outputs.

    **Arguments**

    Out1 : dictionary
        Output of the first simulation; ``Out1['rho']['i_j']`` must hold
        the two-site density matrix of sites ``i`` and ``j`` for every
        pair ``1 <= i < j <= ll``.
    Out2 : dictionary
        Output of the second simulation, with the same layout as ``Out1``.
    ll : int
        System size.
    dist : str, optional
        ``'Infidelity'`` (default) selects the infidelity; any other value,
        e.g. ``'TraceDistance'``, selects the trace distance.

    **Returns**

    The maximum of the real parts of the pair-wise distances, never below
    ``0.0`` (``0.0`` is returned when no pair is compared).
    """
    measure = infidelity_rho_rho if dist == 'Infidelity' else tracedist_rho_rho

    # Keys 'i_j' in the order (1, 2), (1, 3), ..., (ll - 1, ll).
    pair_keys = [
        '{}_{}'.format(first, second)
        for first in range(1, ll)
        for second in range(first + 1, ll + 1)
    ]

    largest = 0.0
    for key in pair_keys:
        value = np.real(measure(Out1['rho'][key], Out2['rho'][key]))
        largest = max(largest, value)
    return largest
# ------------------------------------------------------------------------------
# QUANTUM INFIDELITY
# ------------------------------------------------------------------------------
def infidelity(StateA, StateB):
    r"""
    Calculate the infidelity
    :math:`I = 1 - \mathrm{Tr} \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}}`
    between two quantum states.

    Pure states (1d arrays) are routed to the closed forms
    :func:`infidelity_psi_psi` and :func:`infidelity_psi_rho`; two density
    matrices (2d arrays) are handled by :func:`infidelity_rho_rho`.

    **Arguments**

    StateA : 1d or 2d numpy array
        Represents the first quantum state as pure state or density
        matrix.
    StateB : 1d or 2d numpy array
        Represents the second quantum state as pure state or density
        matrix.

    **Raises**

    Whatever :func:`ispure` raises for arrays that are neither 1d nor 2d.
    """
    purea = ispure(StateA)
    pureb = ispure(StateB)
    if purea and pureb:
        return infidelity_psi_psi(StateA, StateB)
    if purea:
        return infidelity_psi_rho(StateA, StateB)
    if pureb:
        # The fidelity is symmetric, so the pure state may always go first.
        return infidelity_psi_rho(StateB, StateA)
    return infidelity_rho_rho(StateA, StateB)
def infidelity_psi_psi(psi, phi):
    r"""
    Calculate the infidelity of two pure states, which simplifies to
    :math:`I = 1 - | \langle \psi | \phi \rangle |`.

    **Arguments**

    psi : 1d numpy array
        Represents the first quantum state for the distance measurement.
    phi : 1d numpy array
        Represents the second quantum state for the distance measurement.
    """
    # np.vdot conjugates its first argument, i.e. it computes <psi|phi>.
    return 1 - np.abs(np.vdot(psi, phi))
def infidelity_psi_rho(psi, rho):
    r"""
    Calculate the infidelity of a pure state and a density matrix, which
    simplifies to
    :math:`I = 1 - \sqrt{\langle \psi | \rho | \psi \rangle}`.

    **Arguments**

    psi : 1d numpy array
        Represents the first quantum state for the distance measurement,
        which is a pure state.
    rho : 2d numpy array
        Represents the second quantum state for the distance measurement,
        which is the density matrix.
    """
    # <psi|rho|psi> is real for a Hermitian rho; drop the round-off
    # imaginary part and clip tiny negative values to zero so the square
    # root does not turn into nan.
    overlap = np.real(psi.conj().dot(rho).dot(psi))
    return 1 - np.sqrt(np.maximum(overlap, 0.0))
def infidelity_rho_rho(rho, sigma):
    r"""
    Calculate the infidelity of two density matrices,
    :math:`I = 1 - \mathrm{Tr} \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}}`.

    **Arguments**

    rho : 2d numpy array
        First density matrix.
    sigma : 2d numpy array
        Second density matrix.
    """
    sqrt_rho = sqrtm(rho)
    # For valid density matrices sqrt(rho) sigma sqrt(rho) is Hermitian,
    # which is what the eigendecomposition-based sqrtm expects.
    inner = np.dot(sqrt_rho, np.dot(sigma, sqrt_rho))
    return np.real(1 - np.trace(sqrtm(inner)))
# ------------------------------------------------------------------------------
# TRACE DISTANCE
# ------------------------------------------------------------------------------
def tracedist(StateA, StateB):
    r"""
    Calculate the trace distance
    :math:`D = \frac{1}{2} \mathrm{Tr} | \rho - \sigma |` with
    :math:`|A| = \sqrt{A^{\dagger} A}` between two quantum states.

    Pure states (1d arrays) are routed to the closed forms
    :func:`tracedist_psi_psi` and :func:`tracedist_psi_rho`; two density
    matrices (2d arrays) are handled by :func:`tracedist_rho_rho`.

    **Arguments**

    StateA : 1d or 2d numpy array
        Represents the first quantum state as pure state or density
        matrix.
    StateB : 1d or 2d numpy array
        Represents the second quantum state as pure state or density
        matrix.

    **Raises**

    Whatever :func:`ispure` raises for arrays that are neither 1d nor 2d.
    """
    purea = ispure(StateA)
    pureb = ispure(StateB)
    if purea and pureb:
        return tracedist_psi_psi(StateA, StateB)
    if purea:
        return tracedist_psi_rho(StateA, StateB)
    if pureb:
        # The trace distance is symmetric, so the pure state may go first.
        return tracedist_psi_rho(StateB, StateA)
    return tracedist_rho_rho(StateA, StateB)
def tracedist_psi_psi(psi, phi):
    r"""
    Calculate the trace distance for two pure states, which simplifies to
    :math:`\sqrt{1 - |\langle \psi | \phi \rangle|^2}`.

    **Arguments**

    psi : 1d numpy array
        Represents the first quantum state for the distance measurement.
    phi : 1d numpy array
        Represents the second quantum state for the distance measurement.
    """
    overlap_sq = np.abs(np.vdot(psi, phi)) ** 2
    # Round-off can push |<psi|phi>|^2 marginally above 1; clip so the
    # square root returns 0 instead of nan.
    return np.sqrt(np.maximum(1 - overlap_sq, 0.0))
def tracedist_psi_rho(psi, rho):
    r"""
    Calculate the trace distance for a pure state and a density matrix, that
    is :math:`\frac{1}{2} \sum_{i} \Lambda_{i}` where :math:`\Lambda_{i}` are
    the singular values of :math:`\rho - | \psi \rangle \langle \psi |`.

    **Arguments**

    psi : 1d numpy array
        Represents the first quantum state for the distance measurement,
        which is a pure state.
    rho : 2d numpy array
        Represents the second quantum state for the distance measurement,
        which is the density matrix.
    """
    # |psi><psi| has entries psi_i * conj(psi_j): the conjugate belongs on
    # the second factor of the outer product.
    projector = np.outer(psi, psi.conj())
    singular_values = np.linalg.svd(rho - projector, compute_uv=False)
    return 0.5 * np.sum(singular_values)
def tracedist_rho_rho(rho, sigma):
    r"""
    Calculate the trace distance
    :math:`D = \frac{1}{2} \mathrm{Tr} | \rho - \sigma |` with
    :math:`|A| = \sqrt{A^{\dagger} A}`, i.e. half the sum of the singular
    values of :math:`\rho - \sigma`.

    **Arguments**

    rho : 2d numpy array
        First density matrix.
    sigma : 2d numpy array
        Second density matrix.
    """
    singular_values = np.linalg.svd(rho - sigma, compute_uv=False)
    return 0.5 * np.sum(singular_values)
# ------------------------------------------------------------------------------
# AUXILIARY FUNCTIONS
# ------------------------------------------------------------------------------
def sqrtm(Math):
    r"""
    Calculate the matrix square root of a Hermitian matrix via its
    eigendecomposition, :math:`\sqrt{M} = V \sqrt{\Lambda} V^{\dagger}`.

    Eigenvalues below ``1e-20``, including small negative values from
    round-off, are set to zero before taking the square root.

    **Arguments**

    Math : 2d numpy array (square matrix)
        Hermitian matrix for calculating :math:`\sqrt{Math}` of the matrix,
        not element-wise.
    """
    vals, vecs = np.linalg.eigh(Math)
    # Clip round-off negatives so np.sqrt does not produce nan.
    vals[vals < 1e-20] = 0.0
    # V diag(s) V^H: scaling the columns of V avoids building diag(s).
    return np.dot(vecs * np.sqrt(vals), np.conj(np.transpose(vecs)))
def ispure(State):
    """
    Inquire if state is a pure state. 1d numpy arrays are pure states, 2d
    numpy arrays are density matrices and not pure.

    **Arguments**

    State : 1d or 2d numpy array
        Representing the state.

    **Returns**

    bool
        ``True`` for a 1d array (pure state), ``False`` for a 2d array
        (density matrix).

    **Raises**

    ValueError
        If ``State`` has any other rank. ``ValueError`` derives from
        ``Exception``, so existing ``except Exception`` handlers still work.
    """
    rank = np.ndim(State)
    if rank == 1:
        return True
    if rank == 2:
        return False
    # Report the actual rank; scalars and rank >= 3 tensors all land here.
    raise ValueError(
        "Passed rank-{} tensor as quantum state; expected a 1d pure state "
        "or a 2d density matrix.".format(rank))