#!/usr/local/anaconda3/bin/python

# Hodgepodge of calls to libcsp routines that reproduce figures from Chad
# Spooner's CSP blog.  Really should be better structured!

# Mike Markowski
# mike.ab3ap@gmail.com
# Mar 2021

from numpy import fft, log10
import mksig
import libcsp as csp
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import time

def dB_to_W(s_dB):
    '''Convert a power level in decibels to linear units (watts).

    Accepts a scalar or a NumPy array (converted element-wise).
    '''
    exponent = s_dB / 10
    return pow(10, exponent)

def W_to_dB(s_W):
    '''Convert linear power (watts) to decibels, i.e. 10*log10(P).

    Accepts a scalar or a NumPy array (converted element-wise).
    '''
    decibels = log10(s_W) * 10
    return decibels

def main():
    '''Reproduce a set of figures from Chad Spooner's CSP blog.

    Generates rectangular-pulse BPSK over AWGN with mksig, runs the libcsp
    estimators on it and plots the results.  The res1..res4 flags select
    which sections run, and any subset may be enabled:

      1. Smoothed periodogram (FSM) of the BPSK signal.
      2. Non-conjugate and conjugate cyclic periodograms via FSM and TSM.
      3. Spectral coherence at a set of non-conjugate/conjugate alphas.
      4. SSCA and FAM estimates of the (f, alpha) plane, plus a PSD
         comparison of FAM, SSCA and TSM.
    '''

    fc_norm = 0.05   # Desired carrier frequency (normalized, fc_Hz/fs_Hz).
    N_psd = 32768    # Number of signal samples used in FSM/TSM estimates.
    numBits = 4000   # Desired number of bits in generated signal.
    sig_dBW = 0.0    # Signal power (dBW) assumed by the power check below.
    pulseWidth = 164 # Width in array elements of unit area smoothing pulse.
    snr_dB = 10.0    # Signal-to-noise ratio (dB) of the generated signal.
    T_bit = 10       # Samples per symbol, 1/T_bit is the bit rate.
    zPad = 4         # Zero-padding factor for the FSM estimates.

    # Debugging: make FFT bins align perfectly with cycle frequencies.
#   T_bit = 8
#   fc_norm = 1/32

    def mkBpsk(nSym):
        '''BPSK over an AWGN channel with nSym symbols (settings above).'''
        c = mksig.configDefault()
        c['bps'] = 1/T_bit
        c['fc_Hz'] = fc_norm
        c['nSym'] = nSym
        c['snr_dB'] = snr_dB
        return mksig.mkSig(c)

    def fsm(sig, alpha, **kwargs):
        '''Zero-padded cyclic periodogram, smoothed in frequency (FSM).'''
        scf = csp.scfFsm(sig, alpha, padFactor=zPad, **kwargs)
        return csp.smooth(scf, zPad*pulseWidth)

    # BPSK signal over AWGN channel.
    sigRx = mkBpsk(numBits)
    unmixed = mksig.demodBpsk(sigRx, fc_norm, T_bit) # Demod for plotting.

#   sigRx = csp.zeroPad(sigRx, 2)
#   N_psd *= 2

    res1 = True
    res2 = True
    res3 = True
    res4 = True

    print('Chad Spooner\'s CSP blog: https://cyclostationary.blog')
    print('Beginners: https://cyclostationary.blog/2019/07/14/for-the-beginner-at-csp/')

    #
    #   1 .   B P S K   a n d   P e r i o d o g r a m
    #

    psd = None # Smoothed periodogram; section 2 reuses it if section 1 ran.
    if res1:
        print('\nCreating a Simple CS Signal: Rectangular-Pulse BPSK')
        print('https://cyclostationary.blog/2015/09/28/creating-a-simple-cs-signal-rectangular-pulse-bpsk/')
        psd = fsm(sigRx[:N_psd], 0) # Smoothed periodogram.
        plotResults1(unmixed, psd)

    #
    #   2 .   S C F ,   F S M   a n d   T S M
    #

    if res2:
        print('\nCSP Estimators: The Frequency-Smoothing Method')
        print('https://cyclostationary.blog/2015/11/20/csp-estimators-the-frequency-smoothing-method/')
        print('CSP Estimators: The Time Smoothing Method')
        # NOTE(review): this URL is the SSCA post's URL (also printed in
        # section 4), not the TSM post -- confirm the intended link.
        print('https://cyclostationary.blog/2016/03/22/csp-estimators-the-strip-spectral-correlation-analyzer/')

        if psd is None: # Section 1 skipped: compute the same periodogram.
            psd = fsm(sigRx[:N_psd], 0)

        # Check power of PSD against known sum of noise and signal powers.
        # NOTE(review): assumes psd.sum()/N_psd is total power for the
        # zero-padded estimate -- confirm against csp.scfFsm's scaling.
        measPwr_W = psd.sum()/N_psd
        nPwr_W = dB_to_W(sig_dBW - snr_dB) # N = S - snr
        sPwr_W = dB_to_W(sig_dBW)
        knownPwr_W = sPwr_W + nPwr_W
        print('PSD-measured power is %.5e, known total power is %.5e'
            % (abs(measPwr_W), knownPwr_W)) # measPwr_W.imag is 0j.

        s = sigRx[:N_psd]
        # Non-conjugate SCFs at alpha = k/T_bit (k = 0 is the PSD) and
        # conjugate SCFs at alpha = 2fc + k/T_bit, k = -1, 0, 1.
        ncAlphas = [k/T_bit for k in range(5)]
        cAlphas = [2*fc_norm + k/T_bit for k in range(-1, 2)]

        # Frequency smoothed cyclic periodograms.
        scfFsmNc = [fsm(s, alpha) for alpha in ncAlphas]
        scfFsmC = [fsm(s, alpha, conj=True) for alpha in cAlphas]

        # Time smoothed cyclic periodograms.
        scfTsmNc = [csp.scfTsm(s, 256, alpha) for alpha in ncAlphas]
        scfTsmC = [csp.scfTsm(s, 256, alpha, conj=True) for alpha in cAlphas]
        plotResults2(scfFsmNc, scfFsmC, scfTsmNc, scfTsmC)

    #
    #   3 .   S p e c t r a l   C o h e r e n c e
    #

    if res3:
        print('\nThe Spectral Coherence Function')
        print('https://cyclostationary.blog/2016/01/08/the-spectral-coherence-function/')
        # NOTE(review): if mkSig yields nSym*T_bit samples, sigRx holds only
        # numBits*T_bit samples, so this slice may be shorter than nSc.
        nSc = 65536 # To match spectral coherence blog.
        s = sigRx[:nSc]

        # Spectral coherence function, non-conjugate.
        alphas = np.linspace(0.1, 0.9, 9)
#       alphas = np.linspace(0.0, 0.9, 10)
        scnc = [csp.sc(s, alpha, padFactor=1) for alpha in alphas]

        # Spectral coherence function, conjugate.
        alphasC = np.linspace(-0.9, 0.9, 19)
#       alphasC = np.linspace(-1.0, 1.0, 21)
        scc = [csp.sc(s, alpha, conj=True, padFactor=1) for alpha in alphasC]
        plotResults3(alphas, scnc, alphasC, scc)

    #
    #   4 .   S S C A   a n d   F A M
    #

    if res4:
        # Settings.
        nMax = 600  # Max number of (f, alpha) points kept by csp.filter.
        N = 65536   # Samples to process per channelizer strip.
        Np = 64     # Number of strips (N*Np points to process).
        doSsca = True
        doFam = True
        compareChad = False # Save plots to compare to Chad's.

        # Enough symbols for the N + Np samples the SSCA consumes.
        sigRx = mkBpsk(int(np.ceil((N + Np)/T_bit)))

        if doSsca:
            print('\nCSP Estimators: The Strip Spectral Correlation Analyzer')
            print('https://cyclostationary.blog/2016/03/22/csp-estimators-the-strip-spectral-correlation-analyzer/')
            xSsca = sigRx[:N+Np]
            # Data generation.
            t0 = time.process_time()
            fNss, alphaNss, sscaNc = csp.ssca(xSsca, Np, N, conj=False)
            t1 = time.process_time()
            print('SSCA NC: %.2f ms' % (1e3*(t1-t0)))
            t0 = time.process_time()
            fCss, alphaCss, sscaC = csp.ssca(xSsca, Np, N, conj=True)
            t1 = time.process_time()
            print('SSCA C: %.2f ms' % (1e3*(t1-t0)))
            scNc = csp.sscaSc(xSsca, sscaNc, conj=False)
            scC = csp.sscaSc(xSsca, sscaC, conj=True)

            # Strongest points by SCF magnitude.
            sortby = 'scf'
            quadsNc = csp.filter(fNss, alphaNss, scNc, sscaNc,
                threshold=0.05, top=nMax, sortby=sortby)
            quadsC = csp.filter(fCss, alphaCss, scC, sscaC,
                threshold=0.05, top=nMax, sortby=sortby)
            plotResults4(quadsNc, quadsC, 'SSCA')

            # Strongest points by spectral coherence.
            sortby = 'sc'
            quadsNc = csp.filter(fNss, alphaNss, scNc, sscaNc,
                threshold=0.1, top=nMax, sortby=sortby)
            quadsC = csp.filter(fCss, alphaCss, scC, sscaC,
                threshold=0.1, top=nMax, sortby=sortby)
            plotSc(quadsNc, quadsC, 'SSCA', sc=True)

        if doFam:
            print('\nCSP Estimators: The FFT Accumulation Method')
            print('https://cyclostationary.blog/2018/06/01/csp-estimators-the-fft-accumulation-method/')
            # Data generation.
            sortby = 'scf'
            thresh = 0.2
            L = 8
            xFam = sigRx[:N]
            t0 = time.process_time()
            fNfam, alphaNfam, famNc = csp.fam(xFam, L, conj=False)
            t1 = time.process_time()
            print('FAM NC: %.2f ms' % (1e3*(t1-t0)))
            t0 = time.process_time()
            fCfam, alphaCfam, famC = csp.fam(xFam, L, conj=True)
            t1 = time.process_time()
            print('FAM C: %.2f ms' % (1e3*(t1-t0)))
            scNcF = csp.famSc(xFam, fNfam, alphaNfam, famNc, conj=False)
            scCF = csp.famSc(xFam, fCfam, alphaCfam, famC, conj=True)
            quadsNc = csp.filter(fNfam, alphaNfam, scNcF, famNc,
                threshold=thresh, top=nMax, sortby=sortby)
            quadsC = csp.filter(fCfam, alphaCfam, scCF, famC,
                threshold=thresh, top=nMax, sortby=sortby)
            plotResults4(quadsNc, quadsC, 'FAM')

            sortby = 'sc'
            quadsNc = csp.filter(fNfam, alphaNfam, scNcF, famNc,
                threshold=0.3, top=nMax, sortby=sortby)
            quadsC = csp.filter(fCfam, alphaCfam, scCF, famC,
                threshold=0.3, top=nMax, sortby=sortby)
            plotSc(quadsNc, quadsC, 'FAM', sc=True)

        #
        #   P S D s   f r o m   F A M ,   S S C A   a n d   T S M
        #
        if doSsca and doFam: # Needs results from both estimators.
            quadsF = csp.filter(fNfam, alphaNfam, scNcF, famNc,
                threshold=thresh, top=nMax, alpha0=0)            # FAM PSD.
            quadsS = csp.filter(fNss, alphaNss, scNc, sscaNc, threshold=0.1,
                top=nMax, alpha0=0)                              # SSCA PSD.
            psd = np.abs(csp.scfTsm(xFam, 64, 0, taper=False))  # TSM PSD.
            scfF = quadsF[:,2]
            scfS = quadsS[:,2]
            fF = fft.fftshift(fft.fftfreq(len(scfF)))
            fS = fft.fftshift(fft.fftfreq(len(scfS)))
            fP = fft.fftshift(fft.fftfreq(len(psd)))
            # The TSM estimate is computed with taper=False; label it so.
            plotData('psd.png', [fF, fS, fP], [scfF, scfS, psd],
                'PSD Comparison', 'Freq, norm (Hz)', 'Ampl (linear)',
                ['FAM', 'SSCA', 'TSM untapered'])

        if compareChad and doSsca:
            # plotChad recomputes coherence from the SSCA input, so pass the
            # same N+Np samples the SSCA estimates were computed from.
            plotChad(xSsca, sscaC, sscaNc, nMax)
#           plotChadFam(xFam, fCfam, alphaCfam, famC, fNfam, alphaNfam, famNc, nMax)

#       plotSurface(sscaNc, scNc)

#
#   P l o t t i n g   R o u t i n e s
#
#   Separated because they are not CSP routines.
#

def plotData(fname, data1, data2=None, title='', xlabel='', ylabel='', key='',
        save=False):
    '''Plot one or more curves in dB (10*log10) on a single set of axes.

    Inputs:
      fname (str) : PNG file name; written only when save is True.
      data1 (list of arrays) : x values for each curve, or, when data2 is
        None, the y values of each curve (plotted against sample index).
      data2 (list of arrays) : y values (linear power) for each curve.
      title, xlabel, ylabel (str) : plot annotations.
      key (list of str) : legend entries, one per curve; '' for no legend.
      save (bool) : also write the figure to fname before showing it.

    Output: none.
    '''
    plt.grid(True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # Branch before iterating: data2 may legitimately be None.
    if data2 is None:
        for y in data1:
            plt.plot(10*log10(y))
    else:
        for x, y in zip(data1, data2):
            plt.plot(x, 10*log10(y))
    if key != '':
        plt.legend(key)
    if save:
        plt.savefig(fname, format='png') # Must precede show(), which may close the figure.
    plt.show()

def plotChad(x, sscaC, sscaNc, nMax):
    '''Save SSCA SCF and coherence slices for comparison with the CSP blog.

    The conjugate and non-conjugate SSCA estimates are each handled for
    every TSM length M in (64, 128, 256, 512).  Coherence is recomputed
    with csp.sscaSc, and the points csp.filter returns around cycle
    frequencies 0 and 0.1 are plotted (alpha = 0 is skipped for the
    conjugate case).  At alpha = 0 a length-M TSM PSD is overlaid.

    Inputs:
      x (complex[]) : signal the SSCA estimates were computed from.
      sscaC, sscaNc (2-D arrays) : conjugate and non-conjugate SSCA
        estimates; the shape is unpacked as (N, Np).
      nMax (int) : maximum number of points csp.filter keeps.

    Output: none.  Each figure is saved as c_<alpha>_<M>.png or
    nc_<alpha>_<M>.png and then closed.  plotChadFam writes the same
    names, so running both overwrites one set.
    '''
    N, Np = sscaC.shape
    sig = x[Np//2:N+Np//2] # Samples used for the overlaid TSM PSD.
    for conj in [True, False]:
        scfs = sscaC if conj else sscaNc
        conjStr = 'Conj' if conj else 'Non-conj'

        for M in [64, 128, 256, 512]:
            # NOTE(review): assumes csp.sscaSc returns (f, alpha, coherence)
            # when M is given; main() uses a single-value return -- confirm.
            # Distinct names keep these grids from being clobbered by the
            # cycle-frequency loop variable and the figure handle below.
            fGrid, alphaGrid, coh = csp.sscaSc(x, scfs, conj=conj, M=M)

            for alpha0 in [0, 0.1]:
                if conj and alpha0 == 0:
                    continue # Don't want these plots.
                quads = csp.filter(fGrid, alphaGrid, coh, scfs,
                    threshold=0.1, top=nMax, alpha0=alpha0)
                print('afsc shape: %s' % str(quads.shape))
                if len(quads) == 0:
                    continue
                # Magnitudes: the quads array may be complex-valued.
                scf = np.abs(quads[:,2])
                sc = np.abs(quads[:,3])
                fS = fft.fftshift(fft.fftfreq(len(quads)))
                if alpha0 == 0: # TSM PSD overlay, shared by both panels.
                    fM = fft.fftshift(fft.fftfreq(M))
                    psdTsm = csp.scfTsm(sig, M, alpha=0).real

                plt.rcParams['axes.grid'] = True
                fig, ax = plt.subplots(2, 1, figsize=(7, 4))
                plt.subplots_adjust(hspace=0.7)

                t = 'SSCA %s SCF and Coherence ' % conjStr
                t += r'for $\alpha$ = %.1f, ' % alpha0
                t += r'N$^\prime$ = %d and T$_{tsm}$ = %d' % (Np, M)
                for axis, title in zip(ax, [t, 'Zoomed']):
                    axis.set_title(title)
                    axis.set_xlabel('Freq (Hz)')
                    axis.set_ylabel('Magnitude')
                    axis.plot(fS, scf, linewidth=0.75, label='scf')
                    axis.plot(fS, sc, linewidth=0.75, label='sc')
                    if alpha0 == 0:
                        axis.plot(fM, psdTsm, linewidth=0.75, label='scf tsm')
                # Zoom the lower panel onto the coherence range.
                ax[1].set_ylim(0, 1 if sc.max() <= 1.02 else 2)
                for axis in ax:
                    axis.legend()

                fname = 'c_' if conj else 'nc_'
                fname += str(alpha0) + '_' + str(M) + '.png'
                plt.savefig(fname, format='png')
                plt.close(fig) # Avoid accumulating open figures.

def plotChadFam(sig, famFC, famAlphaC, famC, famFNc, famAlphaNc, famNc, nMax):
    '''Save FAM SCF and coherence slices for comparison with the CSP blog.

    The conjugate and non-conjugate FAM estimates are each handled for
    every TSM length M in (64, 128, 256, 512).  Coherence is computed with
    csp.famSc, and the points csp.filter returns around cycle frequencies
    0 and 0.1 are plotted (alpha = 0 is skipped for the conjugate case).
    At alpha = 0 a length-M TSM PSD of sig is overlaid.

    Inputs:
      sig (complex[]) : signal the FAM estimates were computed from.
      famFC, famAlphaC, famC : conjugate FAM frequency grid, cycle
        frequency grid and estimate, as returned by csp.fam.
      famFNc, famAlphaNc, famNc : the same for the non-conjugate estimate.
      nMax (int) : maximum number of points csp.filter keeps.

    Output: none.  Each figure is saved as c_<alpha>_<M>.png or
    nc_<alpha>_<M>.png and then closed.  plotChad writes the same names,
    so running both overwrites one set.
    '''
    for conj in [True, False]:
        if conj:
            scfs, famF, famAlpha = famC, famFC, famAlphaC
        else:
            scfs, famF, famAlpha = famNc, famFNc, famAlphaNc
        conjStr = 'Conj' if conj else 'Non-conj'

        for M in [64, 128, 256, 512]:
            coh = csp.famSc(sig, famF, famAlpha, scfs, conj=conj, M=M)
            for alpha0 in [0, 0.1]:
                if conj and alpha0 == 0:
                    continue # Don't want these plots.
                quads = csp.filter(famF, famAlpha, coh, scfs, threshold=0.1,
                    top=nMax, alpha0=alpha0)
                print('afsc shape: %s' % str(quads.shape))
                if len(quads) == 0:
                    continue
                # Magnitudes: the quads array may be complex-valued.
                scf = np.abs(quads[:,2])
                sc = np.abs(quads[:,3])
                fS = fft.fftshift(fft.fftfreq(len(quads)))
                if alpha0 == 0: # TSM PSD overlay, shared by both panels.
                    fM = fft.fftshift(fft.fftfreq(M))
                    psdTsm = csp.scfTsm(sig, M, alpha=0).real

                plt.rcParams['axes.grid'] = True
                fig, ax = plt.subplots(2, 1, figsize=(7, 4))
                plt.subplots_adjust(hspace=0.7)

                t = 'FAM %s SCF and Coherence ' % conjStr
                t += r'for $\alpha$ = %.1f, ' % alpha0
                t += r'N$^\prime$ = 64 and T$_{tsm}$ = %d' % M
                for axis, title in zip(ax, [t, 'Zoomed']):
                    axis.set_title(title)
                    axis.set_xlabel('Freq (Hz)')
                    axis.set_ylabel('Magnitude')
                    axis.plot(fS, scf, linewidth=0.75, label='scf')
                    axis.plot(fS, sc, linewidth=0.75, label='sc')
                    if alpha0 == 0:
                        axis.plot(fM, psdTsm, linewidth=0.75, label='scf tsm')
                # Zoom the lower panel onto the coherence range.
                ax[1].set_ylim(0, 1 if sc.max() <= 1.02 else 2)
                for axis in ax:
                    axis.legend()

                fname = 'c_' if conj else 'nc_'
                fname += str(alpha0) + '_' + str(M) + '.png'
                plt.savefig(fname, format='png')
                plt.close(fig) # Avoid accumulating open figures.

def plotResults1(txSig, psdFsm):
    '''Show a BPSK waveform snippet above its estimated PSD.

    Inputs:
      txSig (complex[]) : BPSK samples; the real part of the first 200
        samples is drawn (main() passes the demodulated signal).
      psdFsm (complex[]) : PSD estimate, drawn as 10*log10(|psdFsm|)
        against centered normalized frequency.

    Output: none.
    '''
    plt.rcParams['axes.grid'] = True
    fig, (axTime, axPsd) = plt.subplots(2, 1, figsize=(7,6))
    plt.subplots_adjust(hspace=0.7) # Room for lower subplot title.

    # Upper panel: time-domain waveform.
    snippet = txSig[:200].real
    axTime.plot(snippet, linewidth=0.75)
    axTime.set_title('Time-Domain Plot of Rectangular-Pulse BPSK '
                     '($T_{bit}$ = 10, $f_c$ = 0)')
    axTime.set_xlabel('Sample Index')
    axTime.set_ylabel('Signal Amplitude')

    # Lower panel: FSM PSD estimate in dB.
    freqs = fft.fftshift(fft.fftfreq(len(psdFsm)))
    axPsd.set_title('Periodogram, Estimated Power Spectrum (FSM) ')
    axPsd.plot(freqs, W_to_dB(abs(psdFsm)), linewidth=0.75)
    axPsd.set_xlabel('Frequency (Normalized)')
    axPsd.set_ylabel('PSD (dB)')
    axPsd.set_xlim(-0.4, 0.4)
    axPsd.set_ylim(-20, 11)

    plt.tight_layout() # Adjust spacing between plots.
    plt.show()

def _scfKeysNc(count):
    '''Legend entries for non-conjugate SCFs at alpha = 0, 1/T0, 2/T0, ...'''
    return ['PSD'] + ['$\\alpha = %d/T_0$' % i for i in range(1, count)]

def _scfKeysC(count):
    '''Legend entries for conjugate SCFs at alpha = 2fc + i/T0, i = -1, 0, ...'''
    keys = []
    for i in range(-1, count - 1):
        if i < 0:
            keys.append('$\\alpha = 2f_c %d/T_0$' % i)
        elif i == 0:
            keys.append('$\\alpha = 2f_c$')
        else:
            keys.append('$\\alpha = 2f_c + %d/T_0$' % i)
    return keys

def _plotScfPanel(axis, scfs, title, keys, lw=0.75):
    '''Plot each SCF in dB against normalized frequency on one axis.'''
    axis.set_title(title)
    freqs = fft.fftshift(fft.fftfreq(len(scfs[0])))
    for scf in scfs:
        # Floor |scf| at 1e-200 so log10 never sees 0.  Taking the floor on
        # a new array leaves the caller's data untouched.
        axis.plot(freqs, W_to_dB(np.maximum(abs(scf), 1e-200)), linewidth=lw)
    axis.legend(keys)
    axis.set_xlabel('Frequency (Normalized)')
    axis.set_ylabel('PSD (dB)')
    axis.set_xlim(-0.4, 0.4)
    axis.set_ylim(-20, 11)

def plotResults2(scfFsmNcs, scfFsmCs, scfTsmNcs, scfTsmCs):
    '''Plot FSM and TSM cyclic periodograms in a 2x2 grid (dB scale).

    Inputs:
      scfFsmNcs (list of complex[]) : FSM non-conjugate SCFs; entry i is
        labeled alpha = i/T0 (entry 0 is labeled PSD).
      scfFsmCs (list of complex[]) : FSM conjugate SCFs; entry i is labeled
        alpha = 2fc + (i-1)/T0.
      scfTsmNcs, scfTsmCs : the same quantities estimated with TSM.

    Output: none.  The input arrays are not modified.
    '''
    plt.rcParams['axes.grid'] = True
    fig, ax = plt.subplots(2, 2, figsize=(14,7))
    plt.subplots_adjust(hspace=0.7) # Room for lower subplot title.

    panels = [
        (ax[0, 0], scfFsmNcs, 'Non-conjugate SCF (FSM) ', _scfKeysNc),
        (ax[0, 1], scfFsmCs, 'Conjugate SCF (FSM) ', _scfKeysC),
        (ax[1, 0], scfTsmNcs, 'Non-conjugate SCF (TSM) ', _scfKeysNc),
        (ax[1, 1], scfTsmCs, 'Conjugate SCF (TSM) ', _scfKeysC),
    ]
    for axis, scfs, title, mkKeys in panels:
        _plotScfPanel(axis, scfs, title, mkKeys(len(scfs)))

    plt.tight_layout() # Adjust spacing between plots.
    plt.show()

def plotResults3(alphas, scs, alphasC, sccs, conj=False):
    '''Plot spectral coherence slices as two 3-D "waterfall" panels.

    Left: non-conjugate coherence, one slice per entry of alphas.  Right:
    conjugate coherence, one slice per entry of alphasC.  Each slice is a
    filled polygon of |coherence| vs normalized frequency, placed at its
    cycle frequency along the y axis.

    Inputs:
      alphas (ndarray) : non-conjugate cycle frequencies, one per scs entry.
      scs (list of complex[]) : non-conjugate coherence slices, equal length.
      alphasC (ndarray) : conjugate cycle frequencies, one per sccs entry.
      sccs (list of complex[]) : conjugate coherence slices, equal length.
      conj (bool) : unused; kept for backward compatibility.

    Output: none.
    '''
    from mpl_toolkits.mplot3d import Axes3D # Registers '3d' on older matplotlib.
    from matplotlib.collections import PolyCollection

    plt.rcParams['axes.grid'] = True
    fig = plt.figure(figsize=plt.figaspect(0.5))

    def addPanel(position, alphaVals, slices, title, yticks, ylim):
        '''Add one 3-D subplot holding one polygon per coherence slice.'''
        ax = fig.add_subplot(1, 2, position, projection='3d')
        n = slices[0].size # All slices are the same length.
        f = fft.fftshift(fft.fftfreq(n))
        # Extend by one bin on each side so every polygon closes at zero.
        f = np.concatenate(([f[0]-1/n], f, [f[-1]+1/n]))
        colors = plt.cm.jet(np.linspace(0, 1, alphaVals.size))

        verts = [list(zip(f, np.abs(np.concatenate(([0j], s, [0j])))))
            for s in slices]
        poly = PolyCollection(verts, facecolors=colors)
        poly.set_edgecolor('#000000')
        poly.set_alpha(0.8)
        ax.add_collection3d(poly, zs=alphaVals, zdir='y')
        ax.set_yticks(yticks)

        ax.set_title(title)
        ax.set_xlabel(r'$f\ /\ f_s$')
        ax.set_xlim3d(-0.5, 0.5)
        ax.set_ylabel(r'$\alpha$, cycle freq')
        ax.set_ylim3d(*ylim)
        ax.set_zlabel('Magnitude (linear)')
        ax.set_zlim3d(0, 1)

    addPanel(1, alphas, scs, 'Spectral Coherence, Non-conjugate',
        [0, 0.2, 0.4, 0.6, 0.8, 1], (1, 0))
    addPanel(2, alphasC, sccs, 'Spectral Coherence, Conjugate',
        [-1, -0.5, 0, 0.5, 1], (1, -1))

    plt.show()

def plotResults4(ncQuads, cQuads, tsmType, sc=False):
    '''Scatter the (f, alpha) locations of the strongest SCF or SC points.

    Inputs:
      ncQuads, cQuads (2-D arrays) : rows returned by csp.filter for the
        non-conjugate and conjugate estimates; the real parts of column 0
        (frequency) and column 1 (cycle frequency) are plotted.
      tsmType (str) : estimator name used in the titles, e.g. 'SSCA'.
      sc (bool) : label the plots as spectral coherence instead of SCF.

    Output: none.  Nothing is drawn if either array is empty.
    '''
    if len(ncQuads) == 0 or len(cQuads) == 0:
        return

    plt.rcParams['axes.grid'] = True
    fig, ax = plt.subplots(2, 1, figsize=(10,10))
    plt.subplots_adjust(hspace=0.7) # Room for lower subplot title.

    what = 'SC' if sc else 'SCF'
    for axis, quads, conjStr in ((ax[0], ncQuads, 'Non-conjugate'),
                                 (ax[1], cQuads, 'Conjugate')):
        # Report the number of points actually plotted, not a fixed count.
        axis.set_title('%s %s %s, %d Largest'
            % (conjStr, tsmType, what, len(quads)))
        axis.scatter(quads[:, 0].real, quads[:, 1].real, s=5)
        axis.set_xlabel('Frequency (Normalized Hz)')
        axis.set_ylabel(r'Cycle Frequency $\alpha$ (Normalized Hz)')
        axis.set_xlim(-0.5, 0.5)
        axis.set_ylim(-1, 1)

    plt.tight_layout() # Adjust spacing between plots.
    plt.show()

def plotSc(ncQuads, cQuads, tsmType, sc=False):
    '''Scatter (f, alpha) locations of the strongest SC or SCF points.

    Same as plotResults4, except that the non-conjugate panel shows only
    non-negative cycle frequencies (0 to 1).

    Inputs:
      ncQuads, cQuads (2-D arrays) : rows returned by csp.filter for the
        non-conjugate and conjugate estimates; the real parts of column 0
        (frequency) and column 1 (cycle frequency) are plotted.
      tsmType (str) : estimator name used in the titles, e.g. 'FAM'.
      sc (bool) : label the plots as spectral coherence instead of SCF.

    Output: none.  Nothing is drawn if either array is empty.
    '''
    if len(ncQuads) == 0 or len(cQuads) == 0:
        return

    plt.rcParams['axes.grid'] = True
    fig, ax = plt.subplots(2, 1, figsize=(10,10))
    plt.subplots_adjust(hspace=0.7) # Room for lower subplot title.

    what = 'SC' if sc else 'SCF'
    panels = ((ax[0], ncQuads, 'Non-conjugate', (0, 1)),
              (ax[1], cQuads, 'Conjugate', (-1, 1)))
    for axis, quads, conjStr, ylim in panels:
        # Report the number of points actually plotted, not a fixed count.
        axis.set_title('%s %s %s, %d Largest'
            % (conjStr, tsmType, what, len(quads)))
        axis.scatter(quads[:, 0].real, quads[:, 1].real, s=5)
        axis.set_xlabel('Frequency (Normalized Hz)')
        axis.set_ylabel(r'Cycle Frequency $\alpha$ (Normalized Hz)')
        axis.set_xlim(-0.5, 0.5)
        axis.set_ylim(*ylim)

    plt.tight_layout() # Adjust spacing between plots.
    plt.show()

def plotSurface(scf, sc):
    '''Plot |SCF| and |SC| of an SSCA estimate as 3-D surfaces.

    Adapted from https://matplotlib.org/stable/gallery/mplot3d/surface3d.html

    Inputs:
      scf (2-D array) : non-conjugate SSCA estimate, unpacked as (N, Np).
      sc (2-D array) : matching spectral coherence, same shape as scf.

    Output: none.
    '''
    from matplotlib import cm
    from mpl_toolkits.mplot3d import Axes3D # Registers '3d' on older matplotlib.

    N, Np = scf.shape
    # Frequency and cycle-frequency coordinates of each (q, k) cell.
    q = (np.arange(N) - N/2).reshape((N,1))
    k = np.arange(Np) - Np/2
    f = (k/Np - q/N)/2
    alpha = k/Np + q/N

    fig = plt.figure(figsize=plt.figaspect(0.5)) # 2 x 1 aspect ratio.

    # (data, title, upper z limit) for each sub-plot.
    panels = ((scf, 'SSCA SCF, Non-conjugate', 3),
              (sc, 'SSCA SC, Non-conjugate', 1))
    for position, (data, title, zmax) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, position, projection='3d')
        ax.set_title(title)
        ax.set_xlabel(r'$f\ /\ f_s$')
        ax.set_xlim3d(-0.5, 0.5)
        ax.set_ylabel(r'$\alpha$, cycle freq')
        ax.set_ylim3d(-1, 1)
        ax.set_zlabel('Magnitude (linear)')
        surf = ax.plot_surface(f, alpha, np.abs(data), cmap=cm.coolwarm,
            antialiased=False)
        ax.set_zlim(0, zmax)
        fig.colorbar(surf, shrink=0.5, aspect=20)

    plt.show()

if __name__ == '__main__':
    # Run the demo sections enabled in main() when executed as a script.
    main()
