from pyestimate.estimators import pc_ar_estimator
import numpy as np
import matplotlib.pyplot as plt
#
# Build the test signal: a sum of p real sinusoids observed in white Gaussian noise
#
N = 40  # number of samples
# Ground-truth parameters, one entry per sinusoid. Frequencies are in
# cycles per sample, because they multiply the integer sample index n below.
f = [0.12345, 0.10345, 0.05234, 0.05, 0.0351]  # frequencies to be estimated
A = [1.23456, 1.34567, 1.0, 1.0, 1.0]  # amplitudes to be estimated
phi = [0.0, np.pi/4, 0.0, 0.0, 0.0]  # phases (radians) to be estimated
p = len(A)  # number of sinusoids
sigma = 0.01  # standard deviation of WGN
n = np.arange(N)  # sample indices 0, 1, ..., N-1

# Accumulate the noise-free signal one sinusoid at a time.
s = 0.0
for amplitude, frequency, phase in zip(A, f, phi):
    s = s + amplitude * np.cos(2*np.pi*frequency*n + phase)  # original signal

# A fixed seed keeps the noise realisation reproducible between runs.
rng = np.random.default_rng(seed=0)
w = rng.normal(scale=sigma, size=N)  # white gaussian noise
x = s + w  # input signal for estimation: sinusoids + noise
#
# Estimate the sinusoid parameters from the noisy samples
#
# The estimator is given the noisy signal and the number of sinusoids; its
# three return values are used below as per-sinusoid amplitude, frequency
# and phase estimates.
A_hat, f_hat, phi_hat = pc_ar_estimator(x, p)
#
# Rebuild the signal from the estimated parameters
#
s_hat = 0.0
for k in range(p):
    omega_hat = 2*np.pi*f_hat[k]  # estimated angular frequency (rad/sample)
    s_hat = s_hat + A_hat[k] * np.cos(omega_hat*n + phi_hat[k])  # estimated signal
#
# Plot the original signal, the noise-corrupted input and the reconstruction
# on a single set of axes for visual comparison
#
fig, ax = plt.subplots()
ax.plot(n, s, linewidth=3.0, label='original signal')  # thick line: noise-free reference
ax.plot(n, x, label='corrupted signal')
ax.plot(n, s_hat, 'k--', label='estimated signal')  # black dashed overlay
ax.set_xlabel('$n$')
ax.set_ylabel('$x[n]$')
ax.set_title('Sum of sinusoids: frequencies, amplitudes and phases estimation in WGN')
ax.legend()
ax.grid()
plt.show()
