#!/usr/local/anaconda3/bin/python
# This code implements mathematics presented by Chad Spooner in
# his blog at https://cyclostationary.blog/
#
# SSCA code is implemented from Eric April's paper, "On the Implementation
# of the Strip Spectral Correlation Algorithm for Cyclic Spectrum Estimation"
# by Eric April. 1995.
#
# Subroutines are in alphabetical order.
#
# Variable names use format thing_units, where underscore separates name from
# its units. E.g., fc_Hz is carrier freq in Hz.
#
# Mike Markowski
# mike.ab3ap@gmail.com
# Mar 2021

from numpy import cos, fft, log2, log10, pi, sqrt
from scipy import signal
from scipy.interpolate import interp1d
import numpy as np
import random
import sys
import time


def binAlpha(alphas, scfs):
    '''For unique alphas, find maximum SCFs.

    Alphas are rounded to 4 decimal places so that very close values become
    identical.  Every alphas[i] corresponds to scfs[i].

    Inputs:
        alphas (float[]) : cycle frequencies, any shape.
        scfs (complex[]) : SCF values, same shape as alphas.
    Outputs:
        (float[], float[]) : unique rounded alphas and the maximum |SCF|
            found for each, both ordered by decreasing |SCF|.
    '''
    ar = np.round(alphas, 4)  # Round to nearest 1e-4.
    # a: sorted unique alphas; inv: index into a for every element of ar.
    a, inv = np.unique(ar, return_inverse=True)
    s = np.zeros(a.size)
    # Per-group maximum in one vectorized pass (|SCF| >= 0, so the zero
    # initial value never wins over a real entry).  ravel() keeps this
    # independent of the NumPy version's shape convention for 'inv'.
    np.maximum.at(s, inv.ravel(), np.abs(np.asarray(scfs)).ravel())
    ind = np.argsort(s)[::-1]  # Index order to reverse sort s.
    return a[ind], s[ind]  # Unique, rounded alphas with max SCFs.


def dB_to_W(val_dB):
    '''Convert a power value in dB to linear units (W).'''
    return 10**(val_dB/10)


def fam(sig, L, Np=0, conj=False, showMults=False):
    '''FFT Accumulation Method, computes spectral correlation estimates over
    its entire principal domain.  It does the same thing as ssca().

    [1] Implementation based on 'Computationally Efficient Algorithms for
        Cyclic Spectral Analysis,' Roberts, Brown and Loomis, IEEE SP
        Magazine, Apr 1991.
    Also, drawn from:
    [2] https://cyclostationary.blog/2018/06/01/csp-estimators-the-fft-accumulation-method/
    [3] "Implementation of Cyclic Spectral Analysis Methods," LCDR Nancy J.
        Carter, 1992, Naval Postgraduate School.
    Equation numbers come from the blog page, [2], above.

    Inputs:
        sig (complex[]) : signal of interest.  Only the first N samples are
            analyzed, where N is the largest power of 2 <= sig.size.
        L (int, dyadic) : hop length, in units of samples.
        Np (int, dyadic) : length of FFTs, N'.  0 (default) selects N'=4L,
            which is recommended.
        conj (boolean) : conjugate (True) or non-conjugate (False) SCF.
        showMults (boolean) : print the multiplication count given in [3].
    Outputs:
        (f_j, alpha, Sx) : spectral freqs (Np*Np x 1), cycle freqs
            (Np*Np x 2*(P//4)) and complex SCF estimates (same shape as
            alpha), where P = N/L.  Frequencies are normalized (fs = 1).
            An empty list is returned if famReady() rejects the inputs.
    Ex.
        N = 65536
        N' = 32
        L = 8
    '''
    N = 2**int(log2(sig.size))
    if N != sig.size:
        print('Using first %d dyadic signal samples of %d.' % (N, sig.size))
    sig = sig[:N]  # Analyze exactly the samples announced and checked.
    if not famReady(sig, L, Np):  # Assortment of sanity checks.
        return []
    if Np == 0:
        Np = 4*L  # Recommended trade-off in [1].
    P = N//L  # Eq 4, will do Np P-point FFTs.  Note N == P*L.
    pad = Np + L*(P - 1) - N  # Zero pad so the last sub-block is full.
    x = np.append(sig, np.zeros(pad))  # x is zero padded sig.
    y = x.conjugate() if conj else x

    # Step 1: Arrange N'-point Data Sub-blocks.
    # [x(0) x(L+0) x(2L+0) ... x((P-1)L + 0)]
    # [x(1) x(L+1) x(2L+1) ... x((P-1)L + 1)]
    # [x(N'-1) x(L+N'-1) x(2L+N'-1) ... x((P-1)L + N'-1)]
    idx = np.arange(Np).reshape((Np, 1)) + L*np.arange(P)  # Np x P.
    X = x[idx]  # Np x P, each col is x' staggered by L samples.
    Y = y[idx] if conj else X

    # Step 2: Apply Data Tapering Window to Columns of X.
    A = signal.get_window('hamming', Np)  # a(r) in Eq 3.
    # NOTE(review): empirical amplitude scaling inherited from the original
    # author ("Manual tweaking. Why '/2'??"); A is scaled by rms(A/2).
    A *= sqrt(np.mean((A/2)**2))
    A = A.reshape((Np, 1))  # Np x 1.
    XA = X*A  # Np x P, tapered.
    YA = Y*A if conj else XA

    # Step 3: Apply Fourier Transform to Windowed Subblocks.
    XAT = fft.fftshift(fft.fft(XA, axis=0), axes=0)  # Column FFTs, Eq 3.
    YAT = fft.fftshift(fft.fft(YA, axis=0), axes=0) if conj else XAT

    # Mix demodulates to baseband.
    f = fft.fftshift(fft.fftfreq(Np))  # Spectral components of XAT.
    f = f.reshape((Np, 1))  # Np x 1.
    q = np.arange(P)*L  # 1 x P, scale frequency by stagger.
    E = np.exp(-2j*pi*f*q)  # Np x P, phase adjustment.
    Xg = XAT*E  # Np x P, demodulates mixed to baseband.
    Yg = YAT*E if conj else Xg

    # Step 4: Multiply Channelized Subblocks Together and FFT.
    # Element [k, m, p] is Xg[k, p]*conj(Yg[m, p]) for every channel pair.
    Sx = Xg[:, np.newaxis, :]*Yg.conj()[np.newaxis, :, :]  # Np x Np x P.
    Sx = Sx.reshape((Np*Np, P))
    Sx = fft.fftshift(fft.fft(Sx), axes=1)  # Row FFTs.
    Sx /= P  # Rectangular smoothing window.
    e = P//4
    Sx = Sx[:, e:3*e]  # Save center half of FFT, locations 1/4 to 3/4.

    # Step 5: Associate Fourier Transform Outputs with Freqs.
    # Spectral (f_j) and cycle (alpha_i) frequencies, row k*Np + m for the
    # channel pair (k, m).
    fv = f.ravel()
    f_j = ((fv[:, np.newaxis] + fv)/2).reshape((Np*Np, 1))  # Eq 7.
    alpha_i = (fv[:, np.newaxis] - fv).reshape((Np*Np, 1))  # Eq 6.
    q = np.arange(-e, e)  # Cycles about alpha_i; same length as Sx cols.
    # Row FFT bins are spaced 1/(P*L) = 1/N in cycle frequency.
    alpha = alpha_i + q/N  # Np*Np x 2e, Eqs 5 & 8.

    if showMults:  # From [3], p. 12.
        m = (6 + 4*Np)*P*Np + (2*P*Np)*(log2(Np) + Np*log2(P))
        print('FAM multiplications: %.1f (x1e6)' % (m/1e6))

    return f_j, alpha, Sx  # Np*Np x 1, Np*Np x 2e, Np*Np x 2e.


def famReady(x, L, Np=0):
    '''Sanity-check fam() inputs.

    Inputs:
        x (complex[]) : signal to be analyzed; its length must be dyadic.
        L (int) : hop length; must be dyadic and no larger than Np.
        Np (int) : FFT length N'.  0 selects 4L, as fam() does.
    Output:
        (boolean) : True if fam() may proceed.  Problems are reported with
            print().  L == Np is allowed but triggers a leakage warning.
    '''
    if Np == 0:
        Np = 4*L

    # Some sanity checks.
    if L > Np:
        print('FAM: Need L <= N\', but %d > %d. Exiting.' % (L, Np))
        return False
    if L == Np:
        s = 'FAM warning: L == N\' == %d yields substantial cycle leakage.' % L
        print(s)

    # Ensure N and L are dyadic.
    N = x.size
    N2 = log2(N)
    L2 = log2(L)
    if int(L2) != L2:
        print('FAM: non-dyadic hop length L=%d. Exiting.' % L)
        return False
    if int(N2) != N2:
        print('FAM: non-dyadic signal length N=%d. Exiting.' % N)
        return False
    return True


def famResolution(N, Np, q):
    '''Resolution of a FAM result, with normalized sample rate fs = 1.

    'd' is short for delta, as used in the paper.

    Inputs:
        N (int) : number of signal samples processed.
        Np (int) : FFT length N'.
        q (int or int[]) : cycle offset(s) about alpha_i.
    Outputs:
        (df, dalpha, dtdf) : spectral resolution (depends on q), cycle freq
            resolution fs/N, and the time-frequency product dt*(fs/Np).
    '''
    fs = 1  # Normalized fs, used to keep equations general.
    dt = N/fs  # Delta T, length of signal.
    da = fs/Np  # Resolution determined by tapering window.
    dalpha = fs/N  # Cycle freq resoln deps on points processed.
    df = da - np.abs(q)*dalpha  # Freq resolution is function of q.
    dtdf = dt*da  # dt*df when q == 0.
    return df, dalpha, dtdf


def famSc(x, f, alpha, famScf, conj=False, M=64):
    '''Convert FAM spectral correlation function (SCF) output to spectral
    coherences (SC).

    Inputs:
        x (complex[]) : 1d signal whose PSD is used in SC calculations.
        f (float[]) : spectral freqs from fam(), Np*Np x 1.
        alpha (float[]) : cycle freqs from fam(), same shape as famScf.
        famScf (complex[][]) : SCF matrix from fam().  Not modified.
        conj (boolean) : conjugate (True) or non-conjugate coherence.
        M (int) : SCF TSM block size used for the PSD, default 64.
    Output:
        (complex[][]) : spectral coherence for every element of famScf[][];
            0 where the PSD denominator is (nearly) 0.
    '''
    # Calculate PSD and function to interpolate it.
    if True:  # SCF TSM
        psd = scfTsm(x, M, alpha=0, taper=False)
    else:  # SCF FSM (alternative PSD estimator, currently disabled)
        zPad = 4
        psd = scfFsm(x, alpha=0, padFactor=zPad)
        psd = smooth(psd, zPad*64)
    fP = fft.fftshift(fft.fftfreq(psd.size))  # PSD spectral components.
    fn = interp1d(fP, psd, fill_value=(psd[0], psd[-1]), bounds_error=False,
        kind='nearest')

    # PSD evaluated at f + alpha/2 and at f - alpha/2 (alpha/2 - f if conj).
    dn = f + alpha/2
    up = alpha/2 - f if conj else f - alpha/2
    dnI = fn(dn)
    upI = fn(up)
    denom = sqrt(dnI*upI)  # Same shape as famScf.
    small = np.abs(denom) < 1e-10  # Div by 0 (or nearly) locations.
    sc = np.array(famScf)  # Work on a copy; don't clobber caller's SCF.
    sc[small] = 0  # Return 0 coherence for divisions by 0.
    denom[small] = 1
    return sc/denom  # Spectral coherences.


def filter(f, alpha, sc, scf, threshold=0, top=0, alpha0=None, sortby='sc'):
    '''Return matrix whose rows are [f, alpha, |SCF|, |SC|].

    Performs Step 5 from April's paper: "On the Implementation of the Strip
    Spectral Correlation Algorithm for Cyclic Spectrum Estimation" by Eric
    April. 1995.

    Inputs:
        f (float[]) : spectral freqs, rows x cols (SSCA) or rows x 1 (FAM).
        alpha (float[]) : cycle freqs, rows x cols.
        sc (complex[]) : spectral coherence matrix, rows x cols.
        scf (complex[]) : spectral correlations corresponding to sc[].
        threshold (float) : if > 0, keep only rows whose sort column
            exceeds threshold.
        top (int) : maximum number of rows returned; 0 means all.
        alpha0 (float) : if given, instead return rows whose alpha is within
            1/(2*rows) of alpha0, sorted by spectral freq (useful for debug
            plots); 'top' is not applied in that case.
        sortby (str) : 'sc' sorts/thresholds on |SC|, otherwise on |SCF|
            (ties then broken by |SC|).
    Output:
        (float[][]) : k x 4 array of rows [f, alpha, |SCF|, |SC|], highest
            sort value first.  For FAM input only points with
            f +/- alpha/2 inside [-0.5, 0.5] are kept.
    '''
    rows, cols = scf.shape  # rows x cols.
    scAbs = np.abs(sc)  # rows x cols, complex to real magnitude.
    scfAbs = np.abs(scf)  # rows x cols.
    if f.shape != scf.shape:  # FAM: f is rows x 1.
        fp = np.tile(f, cols)  # Turn rows x 1 into rows x cols.
        isFam = True
    else:  # SSCA.
        fp = f
        isFam = False

    # Create quads of [spectral freq, cyclic freq, SCF, spectral coh].
    afsc = np.zeros((rows*cols, 4))
    afsc[:, 0] = fp.flatten()
    afsc[:, 1] = alpha.flatten()
    afsc[:, 2] = scfAbs.flatten()
    afsc[:, 3] = scAbs.flatten()
    col = 3 if sortby == 'sc' else 2

    if isFam:  # Don't waste time unnecessarily doing this on ssca results.
        # Keep normalized freqs where f +/- alpha/2 lies in [-0.5, 0.5].
        hi = afsc[:, 0] + afsc[:, 1]/2
        lo = afsc[:, 0] - afsc[:, 1]/2
        inBand = (hi >= -0.5) & (hi <= 0.5) & (lo >= -0.5) & (lo <= 0.5)
        afsc = afsc[inBand]

    if threshold > 0:
        afsc = afsc[afsc[:, col] > threshold]  # Ignore low values.
    if top == 0:  # Return 'top' number of rows.
        top = sc.size

    # afsc: k x 4 rows of [[f, alpha, scf, sc], ...]
    if alpha0 is None:  # Return 'top' SCF or SC values.
        rows_ = np.ndarray.tolist(afsc)  # Can only sort lists.
        rows_.sort(key=lambda row: row[col:], reverse=True)  # High to low.
        afsc = np.array(rows_[:top]).reshape((-1, 4))
    else:  # Return rows close to specified alpha. Usually, for debug plots.
        col = 1  # Col 1 is alpha.
        i = np.where(np.abs(afsc[:, col] - alpha0) < 1/(2*rows))
        afsc = afsc[i]
        rows_ = np.ndarray.tolist(afsc)  # Can only sort lists.
        rows_.sort(key=lambda row: row[0:])  # Sort by spectral freq.
        afsc = np.array(rows_).reshape((-1, 4))
    return afsc


def hamming(n):
    '''Create a Hamming window containing n points.

    Uses n (not n-1) in the cosine denominator, i.e. the periodic form, with
    coefficients 0.53836 and 0.46164.

    Input:
        n (int) : number of points in Hamming window.
    Output:
        (float[]) : array of Hamming window magnitudes.
    '''
    return 0.53836 - 0.46164*cos(2*pi*np.arange(n)/n)


def mapRange(from0, from1, to0, to1, valFrom):
    '''Map a value from an input range to an output range.

    Plain old linear interpolation (1D).  from0 must differ from from1.

    Inputs:
        from0, from1 (float) : start, stop of input range.
        to0, to1 (float) : start, stop of output range.
        valFrom (float) : value to be mapped.
    Output:
        (float) : valFrom mapped into the output range.
    '''
    frac = (valFrom - from0)/(from1 - from0)  # Normalize to 'from' range.
    vTo = to0 + frac*(to1 - to0)  # Un-normalize to 'to' range.
    return vTo


def periodogram(sig, fftSize):
    '''Estimate a signal's PSD with a frequency-smoothed periodogram.

    See:
    https://cyclostationary.blog/2015/11/20/csp-estimators-the-frequency-smoothing-method/
    Equation numbers in comments below refer to the above web page (as viewed
    in Mar 2021).  This subroutine implements Eq 2.

    Inputs:
        sig (complex[]) : symbol stream whose PSD estimate is wanted.
        fftSize (int) : currently unused; the estimate has sig.size points.
    Output:
        (complex[]) : smoothed periodogram of sig, DC in the center.  Values
            are |X(f)|^2/N, so imaginary parts are 0.
    '''
    # Perform CSP on received signal.
    I = fft.fft(sig)
    I = fft.fftshift(I)  # Move DC to center.
    I = I*I.conjugate()/I.size  # Eq 2, |X(f)|^2 / N.
    # Smoothing window 0.5% of signal, at least one sample wide.
    return smooth(I, max(1, int(0.005*I.size)))


def power_W(sig_V):
    '''Mean power of a signal: mean of |sig_V|**2.'''
    return np.mean(np.abs(sig_V)**2)


def quickPlot(data1, data2=None, title='', xlabel='', ylabel='', key=''):
    '''Quick matplotlib line plot for debugging.

    Plots data1 alone, or data2 against data1 when data2 is an ndarray.
    key (str or list of str), when not '', labels the curve(s) in a legend.
    '''
    import matplotlib.pyplot as plt
    plt.grid(True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if isinstance(data2, np.ndarray):
        plt.plot(data1, data2)
    else:
        plt.plot(data1)
    if key != '':
        # Legend must follow plot(); a bare string would be split per char.
        plt.legend([key] if isinstance(key, str) else key)
    plt.show()


def sc(sig, alpha, conj=False, pulseWidth=0, padFactor=1):
    '''Calculate spectral coherence of a signal at specified cycle frequency.

    https://cyclostationary.blog/2016/01/08/the-spectral-coherence-function/

    Inputs:
        sig (complex[]) : signal to be analyzed.
        alpha (float) : cycle frequency.
        conj (boolean) : want conjugate/non-conjugate cyclic periodogram.
            Default is False.
        pulseWidth (int) : width, in array elements, of smoothing pulse.
            When left at default of 0, is then set to len(sig)//100.
        padFactor (int) : zero-padding factor passed to scfFsm().
    Output:
        (complex[]) : spectral coherence, 0 where the denominator is 0.
    '''
    if pulseWidth == 0:
        pulseWidth = sig.size//100  # Recommended in Chad's blog.
    scf = scfFsm(sig, alpha, conj, padFactor=padFactor)
    scf = smooth(scf, padFactor*pulseWidth)
    psd = scfFsm(sig, alpha=0, padFactor=padFactor)  # SCF for alpha=0 is PSD.
    a1, a2 = shifts(alpha, psd.size, conj)  # Up/down shifts for PSD.
    up = zshift(psd, -a1)  # Shift left a1 elements and zero fill.
    if conj:
        psd = psd[::-1]  # X(-f)
    dn = zshift(psd, a2)  # Shift right a2 elements and zero fill.
    # NOTE(review): the numerator is smoothed with padFactor*pulseWidth but
    # the PSD terms with pulseWidth only; confirm this is intended.
    denom = sqrt(smooth(up, pulseWidth)*smooth(dn, pulseWidth))
    z = np.where(np.abs(denom) == 0)  # Div by 0 locations.
    scf[z] = 0  # Return 0 coherence for divisions by 0.
    denom[z] = 1
    return scf/denom  # Spectral coherence.


def scfFsm(sig, alpha, conj=False, padFactor=4):
    '''Spectral Correlation Function: Frequency Smoothing Method.

    Returns the (unsmoothed) cyclic periodogram; callers smooth it.

    Inputs:
        sig (complex[]) : signal to be analyzed.
        alpha (float) : cycle frequency of interest.
        conj (boolean) : want conjugate/non-conjugate cyclic periodogram.
        padFactor (int) : Chad Spooner illustrates benefits of zero padding at
            https://cyclostationary.blog/2021/05/05/zero-padding-in-spectral-correlation-estimators/#more-10253
            where he shows that a factor of 2 or 4 is best.
    Output:
        (complex[]) : cyclic periodogram, sig.size*padFactor points, DC in
            the center.
    '''
    if padFactor == 1:
        x = sig
    else:
        x = sig.copy()
        x.resize(x.size*padFactor)  # Zero pad.
    I = fft.fft(x)  # Freq domain of signal.
    I = fft.fftshift(I)  # Move DC to center.
    a1, a2 = shifts(alpha, I.size, conj)  # Calculate optimal shifts for SCF.
    up = zshift(I, -a1)  # X(f+a1), shift data up.
    if conj:  # Eq 11.
        dn = I[::-1]  # X(-f)
    else:  # Eq 8.
        dn = I.conjugate()  # X*(f)
    dn = zshift(dn, a2)  # X(f-a2) or X(a2-f) for conjugate.
    I = up*dn/I.size  # Cyclic periodogram.
    return I  # SCF.


def scfTsm(sig, N, alpha, conj=False, taper=False):
    '''Spectral Correlation Function, Time Smoothing Method.

    Implementation of:
    https://cyclostationary.blog/2015/12/18/csp-estimators-the-time-smoothing-method/
    Compute the cyclic periodogram for blocks in time domain and average
    results.

    Inputs:
        sig (complex[]) : signal; zero padded to a multiple of N.
        N (int) : block (FFT) length.
        alpha (float) : cycle frequency.
        conj (boolean) : conjugate/non-conjugate SCF.
        taper (boolean) : apply an rms-scaled Hamming window to each block.
    Output:
        (complex[]) : N-point SCF estimate, DC in the center.
    '''
    S = zeroPad(sig, N)  # Zero pad till sig is multiple of N.
    M = S.size//N  # M blocks of N points.
    S = S.reshape((M, N))  # Prepare for M rows of N-pt FFTs.
    if taper:  # Optional taper recommended by Chad Spooner.
        w = signal.get_window('hamming', N)
        # NOTE(review): rms scaling kept from original ("works").
        S *= w*sqrt(np.mean(w**2))
    I = fft.fftshift(fft.fft(S), axes=1)  # Row FFTs.
    a1, a2 = shifts(alpha, N, conj)  # Find optimal up/down shifts.
    up = zshift(I, -a1)  # Upward shift.
    if conj:
        dn = np.flip(I, axis=1)
    else:
        dn = I.conjugate()
    dn = zshift(dn, a2)  # Downward shift.
    I = up*dn/N  # Each row is a cyclic periodogram.
    u = (np.arange(M)*N).reshape((M, 1))  # Start sample of each block.
    I *= np.exp(-2j*pi*alpha*u)  # Phase compensation.
    S = np.sum(I, axis=0)  # Sum columns, frequency components.
    return S/M  # SCF estimate.


def scfTsmLoop(sig, N, alpha, conj=False):
    '''Loop version of scfTsm(); much slower than the matrix version.

    Typical comparison:
        64 point TSM ffts Loop: 82.3 ms Matrix: 1.4 ms, Speed up: 57.3x
        64 point TSM ffts Loop: 85.2 ms Matrix: 1.7 ms, Speed up: 50.2x

    Spectral Correlation Function, Time Smoothing Method.  Implementation of:
    https://cyclostationary.blog/2015/12/18/csp-estimators-the-time-smoothing-method/
    Compute the cyclic periodogram for blocks in time domain and average
    results.  Trailing samples that do not fill a block are ignored.

    Output:
        (complex[]) : N-point SCF estimate.
    '''
    M = sig.size//N  # M blocks of N points.
    S = np.zeros(N, dtype=complex)  # Eventually is the desired SCF.
    for i in range(M):  # Loop through M segments in time.
        u = i*N  # Left edge of i'th subblock.
        # padFactor=1: each periodogram must have exactly N points to be
        # accumulated into S (scfFsm's default would pad to 4N).
        I = scfFsm(sig[u:u+N], alpha, conj, padFactor=1)
        I *= np.exp(-2j*pi*alpha*u)
        S += I
    return S/M


def shifts(alpha, n, conj=False):
    '''Calculate upward and downward shifts for cyclic periodogram by
    minimizing |alpha*n - |a1| - |a2||.

    Inputs:
        alpha (float) : cycle frequency.
        n (int) : length of signal to shift.
        conj (boolean) : calculate shifts for conjugate or non-conjugate.
    Outputs:
        up, down (int, int) : upward/downward shifts.  For a conjugate shift
            of an even-length signal, down is incremented by 1 to account
            for flipping the array.
    '''
    if alpha == 0:  # For PSD.
        up = dn = 0
    else:
        alphan = alpha*n
        halfAlpha = alphan/2
        a1 = int(np.floor(halfAlpha))  # Amount to shift upward.
        a2 = int(np.ceil(halfAlpha))  # Amount to shift downward.
        m1 = abs(abs(alphan) - abs(a1) - abs(a2))
        m2 = abs(abs(alphan) - abs(a1) - abs(a1))
        m3 = abs(abs(alphan) - abs(a2) - abs(a2))
        # Choose up and down shifts corresponding to min(m1, m2, m3); ties
        # favor m1, then m2.
        if min(m1, m2, m3) == m1:
            up = a1
            dn = a2
        elif min(m1, m2, m3) == m2:
            up = a1
            dn = a1
        else:  # min(m1, m2, m3) == m3
            up = a2
            dn = a2
    if conj and n % 2 == 0:  # Conjugate and even length signal.
        dn += 1  # Needed due to flipping of array.
    return up, dn


def smooth(sig, pulseWidth=None):
    '''Smooth a signal with a unit area pulse.

    Inputs:
        sig (complex[]) - signal to smooth.
        pulseWidth (int) - number of elements of sig array to smooth.
            Default is 1% of sig.size.  Non-integer widths are truncated and
            widths below 1 are treated as 1 (no smoothing).
    Output:
        (complex[]) : smoothed signal, same length as sig (when sig is at
            least as long as the pulse).
    '''
    if pulseWidth is None:
        pulseWidth = 0.01*sig.size
    # np.ones() needs an int, and an empty pulse would make convolve fail.
    pulseWidth = max(1, int(pulseWidth))
    pulse = np.ones(pulseWidth)/pulseWidth  # Unit area rectangle.
    return np.convolve(pulse, sig, mode='same')  # Smoothed signal.


def ssca(x, Np, N=None, conj=False, showMults=False):
    '''Calculate Strip Spectral Correlation Analyzer of a signal.

    From: Implementation directly from Section 3.2, and variables in code
    mirror those in paper:
    [1] "On the Implementation of the Strip Spectral Correlation Algorithm
        for Cyclic Spectrum Estimation" by Eric April. 1995.
    [2] "Implementation of Cyclic Spectral Analysis Methods," LCDR Nancy J.
        Carter, 1992, Naval Postgraduate School.

    Resolution (normalized fs = 1): time span dt = N, spectral resolution
    df = 1/Np (tapering window), cycle resolution dalpha = 1/N.

    Inputs:
        x (complex[]) - signal to perform SSCA on.  Must be N+Np samples
            long; if shorter, N is reduced to x.size - Np.
        Np (int) - number of channels in the channelizer.
        N (int) - number of points to analyze.  Default x.size - Np.
        conj (boolean) : take conjugate/non-conjugate SSCA. Default is False.
        showMults (boolean) : print the multiplication count given in [2].
    Outputs:
        (f, alpha, Sx) : N x Np matrices of spectral freq, cycle freq and
            complex SCF.
    '''
    # Step 1. Create sliding windowed vector of input signal.
    if N is None:
        N = x.size - Np
    if x.size < N + Np:  # Quick error check.
        print('SSCA: signal length %d < N+Np=%d.' % (x.size, (N+Np)))
        N = x.size - Np
        print(' Setting N to %d.' % N)

    # Construct X, Eqn 26.  Rows are built directly because the transposed
    # form (Np x N).transpose() is non-contiguous and made the subsequent
    # FFTs about 10x slower.
    X = np.array([x[i:i+Np] for i in range(N)])  # N x Np staggered rows.

    # Step 2. N (row) FFTs of Np points.
    A = signal.get_window(('chebwin', 96), Np)  # Tapering window, 1 x Np.
    XAT = fft.fftshift(fft.fft(X*A), axes=1)  # Row FFTs, Eqn 27.

    # Step 3. Phase shifts.
    k = (np.arange(Np) - Np/2)/Np  # 1 x Np.
    n = np.arange(N).reshape((N, 1))  # N x 1.
    E = np.exp(-2j*pi*k*n)  # N x Np matrix, Eqn 29.

    # Step 4. Np strip (column) FFTs of N points.
    Xg = XAT*E
    i = Np//2  # Start index of signal centered in N+Np samples.
    sig = x[i:i+N].reshape((N, 1))  # Eqn 30.
    if not conj:
        sig = sig.conjugate()  # Eqn 30.
    Xg *= sig/N  # Eqn 31.  '/N' is rect smoothing window of ampl 1/N.
    Sx = fft.fftshift(fft.fft(Xg, axis=0), axes=0)  # Col FFTs, Eqn 32.

    if showMults:  # From [2], p.34.
        L = 1
        m = 2*Np*((6*N/L + 4*N) + (2*N/L + 2*N)*log2(N))
        print('SSCA multiplications: %.1f (x1e6)' % (m/1e6))

    q = (np.arange(N) - N/2).reshape((N, 1))  # N x 1 column vector.
    k = np.arange(Np) - Np/2  # 1 x Np row vector.
    f = (k/Np - q/N)/2  # N x Np matrix.
    alpha = k/Np + q/N  # N x Np matrix.
    return f, alpha, Sx  # SSCA SCF complex matrix, N x Np.


def sscaSc(x, sscaScf, conj=False, M=64, taper=False):
    '''Convert SSCA spectral correlation function (SCF) output to spectral
    coherences (SC).

    Inputs:
        x (complex[]) : 1d signal (the N+Np samples given to ssca()) whose
            PSD is used in SC calculations.
        sscaScf (complex[][]) - N x Np matrix of SCF values from SSCA.  Not
            modified.
        conj (boolean) : conjugate/non-conjugate coherence.
        M (int) : SCF TSM block size, default 64.
        taper (boolean) : passed to scfTsm() for the PSD estimate.
    Output:
        (complex[][]) : N x Np matrix of spectral coherence values
            corresponding to each element in sscaScf[][]; 0 where the PSD
            denominator is 0.
    '''
    N, Np = sscaScf.shape

    # Calculate shift values, f +/- alpha/2, from ssca()'s grid:
    #   q = {0,1,...,N-1} - N/2     k = {0,1,...,Np-1} - Np/2
    #   f = (k/Np - q/N)/2          alpha = k/Np + q/N
    #   dn = f + alpha/2 = k/Np     up = f - alpha/2 = -q/N
    q = np.arange(N) - N/2  # 1 x N, reshaped into a column below.
    k = np.arange(Np) - Np/2  # 1 x Np row vector.
    dn = k/Np  # 1 x Np, f + alpha/2.
    up = q/N if conj else -q/N  # 1 x N, +/-(alpha/2 - f).

    # Generate PSD for coherence denominator.
    i = Np//2  # Start index of signal.
    sig = x[i:i+N]  # 1 x N, N samples centered in N+Np samples.
    psd = scfTsm(sig, M, alpha=0, taper=taper)
    fCoarse = fft.fftshift(fft.fftfreq(M))  # Freqs in PSD.
    fn = interp1d(fCoarse, psd, fill_value=(psd[0], psd[-1]),
        bounds_error=False, kind='nearest')
    dnI = fn(dn)  # 1 x Np row.
    upI = fn(up).reshape((N, 1))  # N x 1 col.
    denom = sqrt(dnI*upI)  # N x Np matrix.
    z = np.where(denom == 0)  # Div by 0 locations.
    sc = np.array(sscaScf)  # Work on a copy; don't clobber caller's SCF.
    sc[z] = 0  # Return 0 coherence for divisions by 0.
    denom[z] = 1
    return sc/denom  # Spectral coherences, N x Np matrix.


def W_to_dB(val_W):
    '''Convert a linear power value (W) to dB.'''
    return 10*log10(val_W)


def zeroPad(x, n):
    '''Zero pad x until its length is an integer multiple of n.

    Inputs:
        x (numpy.array) : signal to be zero padded.
        n (int) : x will be zero padded to be a multiple of n.
    Output:
        (numpy array) : copy of x, zero padded as necessary.
    '''
    xz = x.copy()
    if x.size % n != 0:
        pad = n - (x.size % n)
        xz.resize(x.size + pad)  # resize() fills new elements with zeros.
    return xz


def zshift(a, n):
    '''Right shift array by n elements and zero fill.

    2D arrays are shifted along their rows (axis 1).  For example,
        >>> a = np.array([1,2,3,4,5,6,7])
        >>> zshift(a, 3)
        array([0, 0, 0, 1, 2, 3, 4])
        >>> zshift(a, -2)
        array([3, 4, 5, 6, 7, 0, 0])

    Inputs:
        a (np.array[]) - 1d or 2d array to roll right and zero fill.
        n (int) - number of positions to roll array to right (negative
            rolls left).
    Output:
        (np.array[]) - rolled and zero filled array, or [] (with a printed
            message) for arrays of any other dimension.
    '''
    if len(a.shape) == 1:
        b = np.roll(a, n)
        if n >= 0:
            b[:n] = 0
        else:
            b[n:] = 0
        return b
    elif len(a.shape) == 2:
        b = np.roll(a, n, axis=1)
        if n >= 0:
            b[:, :n] = 0
        else:
            b[:, n:] = 0
        return b
    else:
        print('zshift: can\'t shift %d dimension array.' % len(a.shape))
        return []