import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colorbar import Colorbar
import spharpy
from spharpy.plot import balloon_wireframe


def _plot_axes_indicator(ax, view_angle):
    """Draw labelled unit x/y/z axes on the 3D axes `ax` as an orientation key.

    Parameters
    ----------
    ax : mpl_toolkits.mplot3d.axes3d.Axes3D
        Target 3D axes.
    view_angle : tuple
        Arguments forwarded to ``ax.view_init`` so the key matches the
        orientation of the balloon plots.
    """
    ax.plot([0, 1], [0, 0], [0, 0], color='k')
    ax.plot([0, 0], [0, 1], [0, 0], color='k')
    ax.plot([0, 0], [0, 0], [0, 1], color='k')
    ax.text(0, 0, 1.1, 'z')
    ax.text(0, 1.1, 0, 'y')
    ax.text(1.3, -.05, 0, 'x')

    ax.set_box_aspect(np.ones(3))
    ax.view_init(*view_angle)
    ax.set_axis_off()


def _add_phase_colorbar(cax):
    """Draw a horizontal phase colorbar from 0 to 2*pi into the axes `cax`.

    Parameters
    ----------
    cax : matplotlib.axes.Axes
        Axes that the colorbar is drawn into.

    Returns
    -------
    cb : matplotlib.colorbar.Colorbar
        The created colorbar.
    """
    # NOTE(review): assumes balloon_wireframe's 'phase' encoding uses
    # phase_twilight over [0, 2*pi]; confirm so the key matches the plots.
    cnorm = plt.Normalize(0, 2*np.pi)
    cmappable = mpl.cm.ScalarMappable(cnorm, spharpy.plot.phase_twilight())
    cmappable.set_array(np.linspace(0, 2*np.pi, 128))

    cb = Colorbar(
        ax=cax, mappable=cmappable,
        orientation='horizontal', ticklocation='bottom')
    cb.set_label('Phase in rad')
    cb.set_ticks(np.linspace(0, 2*np.pi, 5))
    cb.set_ticklabels([r'$0$', r'$\pi/2$', r'$\pi$', r'$3\pi/2$', r'$2\pi$'])
    return cb


def plot_basis_functions(Y, sampling, n_max=2):
    """Plot spherical harmonic basis functions as a pyramid of balloon plots.

    One phase-coloured 3D wireframe balloon plot is drawn per basis
    function up to order `n_max`. Row ``n`` of the figure holds the
    functions of order ``n``, with degree ``m = 0`` in the centre column
    and negative/positive degrees to its left/right. A shared horizontal
    phase colorbar is drawn below the plots. For ``n_max >= 1`` a small
    x/y/z axes indicator is drawn in the otherwise empty top-left cell.

    Parameters
    ----------
    Y : array_like
        2-D array of basis function values. Column ``acn`` is plotted for
        every ACN in ``range((n_max+1)**2)``, so `Y` needs at least that
        many columns. Presumably the rows correspond to the points of
        `sampling`; confirm against ``balloon_wireframe``.
    sampling : object
        Sampling positions, passed unchanged to
        ``spharpy.plot.balloon_wireframe``.
    n_max : int, optional
        Highest spherical harmonic order to plot; must be non-negative.
        The default is 2.

    Returns
    -------
    axs : list
        The 3D axes of the basis function plots, in ACN order.
    gs : matplotlib.gridspec.GridSpec
        Layout grid with ``n_max + 2`` rows (the last one holds the
        colorbar) and ``2*n_max + 1`` columns.

    Raises
    ------
    ValueError
        If `n_max` is negative.
    """
    if n_max < 0:
        raise ValueError(f"n_max must be non-negative, got {n_max}.")

    n_cols = 2*n_max + 1
    fig = plt.figure(figsize=(8, 6))
    # The last row is a thin strip reserved for the shared phase colorbar.
    gs = plt.GridSpec(
        n_max+2, n_cols, height_ratios=np.r_[np.ones(n_max+1), 0.1])

    view_angle = (30, 30)

    axs = []
    for acn in range((n_max+1)**2):
        n, m = spharpy.spherical.acn_to_nm(acn)
        # Column n_max is the centre column, which holds degree m = 0.
        ax = fig.add_subplot(gs[n, n_max + m], projection='3d')

        balloon_wireframe(
            sampling, Y[:, acn], cmap_encoding='phase', colorbar=False,
            ax=ax)
        ax.set_title(f'$Y_{{{n}}}^{{{m}}}(\\theta, \\phi)$')
        ax.set_axis_off()
        ax.view_init(*view_angle)
        axs.append(ax)

    # The top-left cell is free only if there is more than one column; for
    # n_max == 0 it is occupied by Y_0^0, so the indicator is skipped.
    if n_max >= 1:
        _plot_axes_indicator(
            fig.add_subplot(gs[0, 0], projection='3d'), view_angle)

    # Leave the outer columns free when possible. With a single column the
    # slice 1:-1 would be empty, and GridSpec raises IndexError for it.
    cbar_cols = slice(1, -1) if n_cols > 2 else slice(None)
    _add_phase_colorbar(fig.add_subplot(gs[n_max+1, cbar_cols]))

    fig.tight_layout()

    return axs, gs