#!/usr/bin/env python

'''
This program takes an arbitrary binary IQ signal or Matlab 5.0 file as input
and analyses it using the techniques present in Chad Spooner's CSP blog at
https://cyclostationary.blog .  No mathematics is performed beyond what is
shown there.  However, the program groups the plots together, and uses
canned text to create a report.  The report is not complete and is simply raw
CSP data.  The person analyzing the signal must augment the report with
an interpretation and a description of what each major cycle frequency is.

Please excuse the poor commenting.  I hope to correct that!  Till then, start
with main() at the bottom of this file.

Mike Markowski
mike.ab3ap@gmail.com
2021
'''

import os
import subprocess
import sys
import time

import numpy as np
import scipy.io as sio

import cew.util as util
import cspPlot
import libcsp as csp
import tex

def analyze(config, maxBlind=500, maxScf=10):
    '''Run the CSP analysis described by config and collect the results.

    config: dictionary from cmdLine(); its 'blind', 'scf', 'sig', 'thresh'
        and 'win' entries are used here.
    maxBlind: passed to blind() as the maximum number of detections kept.
    maxScf: number of leading cycle frequencies (of each type) for which
        csp.sc() is evaluated.

    Returns a dictionary of results; its entries are annotated where it is
    built at the end of this function.
    '''
    blindType, scfType, sigFile, threshSc, w = (
        config['blind'], config['scf'], config['sig'], config['thresh'],
        config['win'])

    fs_Hz, fc_Hz, bw_Hz, sig = readSig(sigFile)

    # Blind detection rows are [f, alpha, scf, sc].
    quadsN, quadsC = blind(blindType, threshSc, sig, maxBlind)

    # Reduce the detections to unique cycle frequencies.  csp.binAlpha is
    # given column 2 (SCF), not column 3 (coherence); its second return
    # value is not used.  Conjugate results are binned first.
    alphaC, _ = csp.binAlpha(quadsC[:, 1], quadsC[:, 2])
    alphaN, _ = csp.binAlpha(quadsN[:, 1], quadsN[:, 2])
    alphaC = alphaC[alphaC >= 0] # Drop symmetric (negative) duplicates.
    alphaN = alphaN[alphaN > 0]  # Also drop alpha=0, i.e. the PSD.

    # Spectral coherence for the first maxScf cycle frequencies of each type.
    scC = np.array([csp.sc(sig, a, conj=True) for a in alphaC[:maxScf]])
    scN = np.array([csp.sc(sig, a) for a in alphaN[:maxScf]])

    # SCF at every retained cycle frequency.  The window chosen by the
    # non-conjugate pass is handed on to the conjugate pass.
    scfN, w = scf(scfType, sig, alphaN, conj=False, fs_Hz=fs_Hz, win=w)
    scfC, w = scf(scfType, sig, alphaC, conj=True, fs_Hz=fs_Hz, win=w)

    return {
        'alphaC': alphaC,     # Conjugate cycle freqs.
        'alphaN': alphaN,     # Non-conjugate cycle freqs.
        'bw_Hz': bw_Hz,       # Bandwidth from the signal file.
        'fc_Hz': fc_Hz,       # Center frequency from the signal file.
        'fs_Hz': fs_Hz,       # Sample rate from the signal file.
        'maxBlind': maxBlind,
        'maxScf': maxScf,
        'quadsC': quadsC,     # All conjugate blind detections.
        'quadsN': quadsN,     # All non-conjugate blind detections.
        'scC': scC,           # Conjugate spectral coherences (top maxScf).
        'scN': scN,           # Non-conjugate spectral coherences (top maxScf).
        'scfC': scfC,         # Conjugate SCFs.
        'scfN': scfN,         # Non-conjugate SCFs.
        'win': w,             # Window/block size used by scf().
    }

def blind(blindType, thresh, sig, nMax=500, sortby='sc'):
    '''Blindly estimate SCFs and coherences, keeping the strongest points.

    blindType: 'ssca' selects csp.ssca/csp.sscaSc; any other value selects
        csp.fam/csp.famSc.
    thresh: threshold handed to csp.filter().
    sig: signal samples.
    nMax: maximum number of points csp.filter() keeps (its 'top').
    sortby: ranking key handed to csp.filter().

    Returns (quadsN, quadsC): filtered non-conjugate and conjugate results.
    '''
    def strongest(f, a, coh, est):
        # Same filter settings for both conjugate types.
        return csp.filter(f, a, coh, est, threshold=thresh, top=nMax,
            sortby=sortby)

    if blindType == 'ssca':
        Np = 64 # Number of strips (N*Np points to process).
        fN, aN, estN = csp.ssca(sig, Np, conj=False)
        fC, aC, estC = csp.ssca(sig, Np, conj=True)
        cohN = csp.sscaSc(sig, estN, conj=False)
        cohC = csp.sscaSc(sig, estC, conj=True)
    else: # 'fam'
        L = 8 # Number of samples to slide window.
        fN, aN, estN = csp.fam(sig, L, conj=False)
        fC, aC, estC = csp.fam(sig, L, conj=True)
        cohN = csp.famSc(sig, fN, aN, estN, conj=False)
        cohC = csp.famSc(sig, fC, aC, estC, conj=True)

    quadsN = strongest(fN, aN, cohN, estN)
    quadsC = strongest(fC, aC, cohC, estC)
    return quadsN, quadsC

def cmdLine(argv):
    '''Parse command-line arguments into a configuration dictionary.

    argv: full argument vector with argv[0] the program name (as sys.argv).
        Recognized options:
          -b fam|ssca  blind detection method (default ssca)
          -i file      input signal file (required)
          -o dir       output directory (required)
          -s fsm|tsm   SCF estimator (default tsm)
          -t float     spectral coherence threshold (default 0.1)
          -w int       FSM smoothing window or TSM block size
    An invalid, value-less or unrecognized option calls usage(), which
    prints help and exits.

    Returns a dict with keys 'blind', 'cmd', 'out', 'scf', 'sig', 'thresh'
    and 'win'.  'cmd' is a command line that reproduces this run, listing
    only options that differ from their defaults.
    '''

    # Defaults.
    blind = 'ssca'
    outDir = ''
    scf = 'tsm'
    sigFile = ''
    thresh = 0.1
    win = None

    def optValue(i):
        '''Return argv[i], the value of option argv[i-1]; usage() if absent.'''
        if i >= len(argv):
            print('Option %s requires a value.' % argv[i-1])
            usage()
        return argv[i]

    i = 1 # argv[0] is the program name, not an option.
    while i < len(argv):
        arg = argv[i]
        if arg == '-b': # Blind detection, default 'ssca'.
            i += 1
            blind = optValue(i).strip().lower()
            if blind not in ['fam', 'ssca']:
                usage()
        elif arg == '-i': # Input signal file.
            i += 1
            sigFile = optValue(i)
        elif arg == '-o': # Output directory.
            i += 1
            outDir = optValue(i)
        elif arg == '-s': # SCF type, default 'tsm'.
            i += 1
            scf = optValue(i).strip().lower()
            if scf not in ['fsm', 'tsm']:
                usage()
        elif arg == '-t': # Spectral coherence threshold.
            i += 1
            value = optValue(i)
            try:
                thresh = float(value)
            except ValueError:
                print('threshold -t \'%s\' must be float.' % value)
                usage()
        elif arg == '-w': # FSM smoothing window or TSM block size.
            i += 1
            value = optValue(i)
            try:
                win = int(value)
            except ValueError:
                print('window -w \'%s\' must be integer.' % value)
                usage()
        else: # Don't silently ignore typos such as '-T'.
            print('Unrecognized option \'%s\'.' % arg)
            usage()
        i += 1
    if outDir == '' or sigFile == '':
        if outDir == '':
            print('Missing output directory name.')
        if sigFile == '':
            print('Missing signal file name.')
        print('')
        usage()

    # Generate command to recreate results.
    cmd = '%s -i %s -o %s ' % (os.path.basename(argv[0]), sigFile, outDir)
    if blind != 'ssca':
        cmd += '-b %s ' % blind
    if scf != 'tsm':
        cmd += '-s %s ' % scf
    if thresh != 0.1:
        cmd += '-t %s ' % thresh
    if win is not None:
        cmd += '-w %d ' % win

    # Create configuration dictionary as return value.
    config = {}
    config['blind'] = blind
    config['cmd'] = cmd
    config['out'] = outDir
    config['scf'] = scf
    config['sig'] = sigFile
    config['thresh'] = thresh
    config['win'] = win
    return config

def isMat(fname):
    '''Return True/False if file is/isn't a Matlab 5.0 signal file.

    The first 10 bytes are compared with the ASCII text 'MATLAB 5.0'.  The
    comparison is done on raw bytes: raw binary i/q data is generally not
    valid UTF-8, so decoding it first could raise UnicodeDecodeError.
    '''
    with open(fname, 'rb') as f: # Closed even if the read fails.
        return f.read(10) == b'MATLAB 5.0'

def readBinary(rawfile, dtype=complex):
    '''Return signal parameters and complex i/q data from a raw recording.

    The file name (its directory is ignored) must contain the substrings
    *_fc_<Hz>_*  and  *_fs_<Hz>_*
    where <Hz> is replaced by an appropriate real number in Hz.  An example
    file name is 'wrff_fs_500e3_fc_104.5e6_.iq'.  A missing or unparsable
    value is reported and replaced by fc=0 or fs=1.

    rawfile: path of the recording.
    dtype: numpy dtype of the stored samples, default complex (the type
        previously hard-coded here).

    Returns (fs_Hz, fc_Hz, bw_Hz, iq).  The file name carries no bandwidth,
    so fs_Hz is also returned as bw_Hz.
    '''
    fs_Hz = fc_Hz = None # None until found in the file name.
    fields = os.path.basename(rawfile).split('_')
    i = 0
    while i < len(fields):
        hasValue = i + 1 < len(fields) # A trailing 'fs'/'fc' has no value.
        if fields[i] == 'fs' and hasValue:
            i += 1
            try:
                fs_Hz = float(fields[i])
            except ValueError:
                print('Improper sample rate: \'%s\''  % fields[i])
        elif fields[i] == 'fc' and hasValue:
            i += 1
            try:
                fc_Hz = float(fields[i])
            except ValueError:
                print('Improper center frequency: \'%s\''  % fields[i])
        i += 1

    if fc_Hz is None:
        print('File name omits substring _fc_VAL_, assuming fc=0.')
        fc_Hz = 0
    if fs_Hz is None:
        print('File name omits substring _fs_VAL_, assuming fs=1.')
        fs_Hz = 1
    iq = np.fromfile(rawfile, dtype=dtype)
    return fs_Hz, fc_Hz, fs_Hz, iq # 2nd fs_Hz is used as bw_Hz.

def readMat(fname):
    '''Return signal parameters and complex i/q data from a Matlab 5.0 file.

    Matlab variables used (scalars are read as element [0][0]):
      Y             i/q samples; required, the program exits with status 1
                    if it is absent.
      XDelta        sample spacing, giving fs = 1/XDelta.
      InputCenter   center frequency in Hz.
      Span          bandwidth in Hz.
      FreqValidMin, FreqValidMax
                    band edges; bandwidth is their difference and center
                    frequency their midpoint when Span or InputCenter is
                    absent.
    Values that can be neither read nor derived are replaced, with a
    printed note, by the assumptions readBinary() makes: fs=1, fc=0 and
    bw=fs.

    Returns (fs_Hz, fc_Hz, bw_Hz, iq).
    '''
    hdr = sio.loadmat(fname) # Dictionary of Matlab variable names/values.

    def scalar(var):
        '''Return the scalar stored in Matlab variable var, or None.'''
        return hdr[var][0][0] if var in hdr else None

    if 'Y' not in hdr:
        print('Signal data not found in variable Y.  Quitting.')
        print('Variables found:')
        for key in hdr:
            print('  %s' % key)
        sys.exit(1)
    iq = hdr['Y'].flatten() # I/Q data of signal.

    xDelta = scalar('XDelta')
    fc_Hz = scalar('InputCenter')
    bw_Hz = scalar('Span')
    f0_Hz = scalar('FreqValidMin')
    f1_Hz = scalar('FreqValidMax')
    haveEdges = f0_Hz is not None and f1_Hz is not None

    if xDelta is None:
        print('File lacks variable XDelta, assuming fs=1.')
        fs_Hz = 1
    else:
        fs_Hz = 1/xDelta
    if bw_Hz is None:
        if haveEdges:
            bw_Hz = f1_Hz - f0_Hz
        else:
            print('File lacks Span and FreqValidMin/Max, assuming bw=fs.')
            bw_Hz = fs_Hz
    if fc_Hz is None:
        if haveEdges:
            fc_Hz = (f0_Hz + f1_Hz)/2
        else:
            print('File lacks InputCenter and FreqValidMin/Max, '
                'assuming fc=0.')
            fc_Hz = 0
    return fs_Hz, fc_Hz, bw_Hz, iq

def readSig(filename):
    '''Read a signal file, choosing the reader from the file's contents.

    Files recognized by isMat() are read with readMat(); anything else is
    assumed to be raw binary i/q and read with readBinary().

    Returns (fs_Hz, fc_Hz, bw_Hz, iq).
    '''
    reader = readMat if isMat(filename) else readBinary
    fs, fc, bw, samples = reader(filename)
    return fs, fc, bw, samples

def _alphaKeys(alphas, fs_Hz):
    '''Return LaTeX legend labels for the cycle frequencies in alphas.

    Each alpha is multiplied by fs_Hz and expressed with the engineering
    prefix that util.engNot()/util.siPrefix() select for fs_Hz itself.
    '''
    _, xPwr10, _ = util.engNot(fs_Hz) # Same prefix for every label.
    xPre = util.siPrefix(xPwr10)
    return ['$\\alpha = %.2f$ %sHz' % ((fs_Hz*alpha)/10**xPwr10, xPre)
        for alpha in alphas]

def report(cfg, res, cpu_s):
    '''Create plots and a LaTeX/PDF report of CSP analysis results.

    cfg: configuration dictionary from cmdLine().
    res: results dictionary from analyze().
    cpu_s: CPU seconds spent on the analysis, quoted in the report.

    Everything is written to directory cfg['out'], which is created along
    with any missing parent directories.  pdflatex is then run twice in
    that directory (presumably so references settle); its output is
    discarded, and if pdflatex is not installed a note is printed and the
    .tex file is left in place.
    '''
    # From configuration.
    bl = cfg['blind']
    cmdRegen = cfg['cmd']
    do = cfg['out']
    fi = cfg['sig']
    sc = cfg['scf']
    th = cfg['thresh']

    # Only the samples are needed here; fs and fc are taken from res below.
    _, _, _, sigIq = readSig(fi)

    os.makedirs(do, exist_ok=True) # Ok if directory exists.

    #
    #   C r e a t e   F i l e n a m e s
    #

    # Filenames of plots, relative to the output directory.  The LaTeX file
    # refers to these names and is built in that directory.
    blindCfile = 'blindC.png' # Scatter, blind conjugate.
    blindNfile = 'blindN.png' # Scatter, blind non-conjugate.
    cyclesCfile = 'cyclesC.png'
    cyclesNfile = 'cyclesN.png'
    scfCfile = 'scfC.png'     # Curve, SCF conjugate.
    scfNfile = 'scfN.png'     # Curve, SCF non-conjugate.
    sigFfile = 'sigF.png'     # Freq spectrum of signal.
    sigTfile = 'sigT.png'     # Time vs mag of signal.
    slicesCfile = 'slicesC.png'
    slicesNfile = 'slicesN.png'

    # Create paths inside the output directory, used when writing plots.
    blindCAbs = os.path.join(do, blindCfile)
    blindNAbs = os.path.join(do, blindNfile)
    cyclesCAbs = os.path.join(do, cyclesCfile)
    cyclesNAbs = os.path.join(do, cyclesNfile)
    scfCAbs = os.path.join(do, scfCfile)
    scfNAbs = os.path.join(do, scfNfile)
    sigFAbs = os.path.join(do, sigFfile)
    sigTAbs = os.path.join(do, sigTfile)
    slicesCAbs = os.path.join(do, slicesCfile)
    slicesNAbs = os.path.join(do, slicesNfile)

    #
    #   C r e a t e   P l o t s
    #

    alphaC = res['alphaC'] # Conjugate cycle freqs.
    alphaN = res['alphaN'] # Non-conj cycle freqs.
    fc_Hz  = res['fc_Hz']  # Center frequency retrieved from signal file.
    fs_Hz  = res['fs_Hz']  # Sample rate retrieved from signal file.
    maxBlind = res['maxBlind']
    maxScf = res['maxScf']
    quadsC = res['quadsC'] # Conjugate blind results, all.
    quadsN = res['quadsN'] # Non-conjugate blind detection, all.
    scC    = res['scC']    # Top conjugate spectral coherences.
    scN    = res['scN']    # Top non-conjugate spectral coherences.
    scfC   = res['scfC']   # Conjugate SCFs.
    scfN   = res['scfN']   # Non-conjugate SCFs.
    win    = res['win']

    # Make plots to use in report.
    cspPlot.spectrum(sigIq, fc_Hz, fs_Hz, foutTime=sigTAbs, foutFreq=sigFAbs)
    cspPlot.scatter(quadsC, fc_Hz=fc_Hz, fs_Hz=fs_Hz, fout=blindCAbs)
    cspPlot.scatter(quadsN, fc_Hz=fc_Hz, fs_Hz=fs_Hz, fout=blindNAbs)
    cspPlot.cycles(alphaC, scfC, fc_Hz=fc_Hz, fs_Hz=fs_Hz, fout=cyclesCAbs)
    cspPlot.cycles(alphaN, scfN, fc_Hz=fc_Hz, fs_Hz=fs_Hz, fout=cyclesNAbs)
    cspPlot.slices(alphaC[:maxScf], scC, fc_Hz, fs_Hz, fout=slicesCAbs)
    cspPlot.slices(alphaN[:maxScf], scN, fc_Hz, fs_Hz, fout=slicesNAbs)

    top = 5 # Number of SCF curves drawn per plot.
    keyC = _alphaKeys(alphaC[:top], fs_Hz)
    keyN = _alphaKeys(alphaN[:top], fs_Hz)
    cspPlot.scf(scfC[:top], key=keyC, fc_Hz=fc_Hz, fs_Hz=fs_Hz, fout=scfCAbs)
    cspPlot.scf(scfN[:top], key=keyN, fc_Hz=fc_Hz, fs_Hz=fs_Hz, fout=scfNAbs)

    #
    #   C r e a t e   L a T e X   R e p o r t
    #

    cspTex = 'csp.tex'
    cspTexAbs = os.path.join(do, cspTex)
    r = tex.LaTeX(cspTexAbs)
    r.beginning()
    r.intro(fi, sigTfile, sigFfile, cmdRegen)
    r.blindScf(blindCfile, blindNfile, bl, maxBlind, th)
    r.cycles(cyclesCfile, cyclesNfile, alphaC, scfC, alphaN, scfN,
        fc_Hz=fc_Hz, fs_Hz=fs_Hz)
    r.slices(slicesCfile, slicesNfile, sc, th, win, maxScf)
    r.scf(scfCfile, scfNfile, sc, th, win, top)
    r.ending(cpu_s)

    # Build the PDF.  An argument list (no shell) keeps directory names with
    # spaces or shell metacharacters intact.  Output is discarded, so
    # nonstop mode plus a closed stdin keep pdflatex from waiting for input
    # after a LaTeX error.
    latex = ['pdflatex', '-interaction=nonstopmode', cspTex]
    for _ in range(2):
        try:
            subprocess.run(latex, cwd=do, stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                check=False)
        except FileNotFoundError:
            print('pdflatex not found; LaTeX source left in %s.' % cspTexAbs)
            break

def scf(scfType, sig, alpha, conj=False, fs_Hz=1, win=None):
    '''Estimate the spectral correlation function at each cycle frequency.

    scfType: 'fsm' (csp.scfFsm followed by csp.smooth) or 'tsm'
        (csp.scfTsm).
    sig: signal samples (numpy array; sig.size is used).
    alpha: iterable of cycle frequencies.
    conj: True for the conjugate SCF, False for the non-conjugate SCF.
    fs_Hz: factor each estimate is multiplied by.
    win: FSM smoothing window or TSM block size, in samples.  When None, FSM
        uses 1% of the signal length (at least 1) and TSM uses 512.

    Returns (scfs, w): the list of scaled estimates in the order of alpha,
    and the window/block size actually used.  Prints a message and returns
    None if scfType is not 'fsm' or 'tsm'.
    '''
    if scfType not in ['fsm', 'tsm']:
        print('scpScf: scfType must be fsm or tsm, got %s.' % scfType)
        return None

    # Resolve the window before the loop so it is defined even for an empty
    # alpha, and so the FSM smoother receives the resolved size.
    if win is not None:
        w = win
    elif scfType == 'fsm':
        w = max(1, int(0.01*sig.size))
    else:
        w = 512

    scfs = []
    for a in alpha:
        if scfType == 'fsm':
            s = csp.smooth(csp.scfFsm(sig, a, conj), w)
        else: # scfType == 'tsm'
            s = csp.scfTsm(sig, w, a, conj)
        scfs.append(fs_Hz*s)
    return scfs, w

def usage():
    '''Print command-line help and exit with status 1.

    The defaults stated here match those set in cmdLine().
    '''
    print('Usage: csp [-b fam|ssca] -i sigFile -o outDir ', end='')
    print('[-s fsm|tsm] [-t num] [-w win]')
    print('')
    print('-b: blind estimation method, default ssca.')
    print('-i: matlab or binary i/q signal file name.')
    print('-o: directory where CSP analysis results will be written.')
    print('-s: type of spectral correlation function to use, default tsm.')
    print('-t: spectral coherence threshold, default 0.1.')
    print('-w: FSM smoothing window or TSM block, in samples.')
    sys.exit(1)

#
#   m a i n
#

def main(argv):
    '''Parse argv, run the analysis, and write the report.

    Only analyze() is timed, using process CPU time; the elapsed seconds are
    handed to report() for inclusion in the document.
    '''
    cfg = cmdLine(argv)
    start_s = time.process_time()
    results = analyze(cfg)
    elapsed_s = time.process_time() - start_s
    report(cfg, results, elapsed_s) # Generate LaTeX doc.

if __name__ == '__main__':
    # Run as a script: hand over the whole argument vector, program name
    # included, since cmdLine() uses argv[0] when building 'cmd'.
    main(argv=sys.argv)
