import numpy as np
import matplotlib.pyplot as plt
from dtk.process import power_spectrum

# Demo script: build one period of a rectangular pulse, compute its power
# spectrum and compare the mean power in the time domain with the sum of the
# spectrum (Parseval's theorem), then plot signal and spectrum.

# sampling parameters
N = 64       # number of samples in one signal period
f_s = 10     # sample rate in Hz
T = N / f_s  # signal period in s

# Build the time axis from integer sample indices instead of
# np.arange(0, T, 1/f_s): with a non-integer step the length of the result
# depends on floating-point rounding and may come out as N or N + 1.
t = np.arange(N) / f_s

# rectangle test signal
A = 3          # amplitude
tau = 0.2 * T  # "on"-time in s

x = np.zeros_like(t)
# int() truncates, so a fractional sample count is rounded down.
x[0:int(tau * f_s)] = A

# power spectrum
# NOTE(review): assumes power_spectrum returns (frequencies, power per bin)
# scaled such that the bins sum to the mean signal power -- that is exactly
# what the check below is meant to reveal; confirm against the dtk docs.
freq, amp = power_spectrum(x, f_s)

# check Parseval's theorem
# Despite their names, both quantities are mean powers (see the print labels).
energy_time = np.mean(np.abs(x) ** 2)
energy_freq = np.sum(amp)

print(f"Mean power in time domain: {energy_time:.6f}")
print(f"Mean power in frequency domain: {energy_freq:.6f}")

# plot: time signal on top, power spectrum below
fig, ax = plt.subplots(2, 1, layout="constrained")
ax[0].stem(t, x)
ax[0].set_xlabel("$t$ in s")
ax[0].set_ylabel("$x(t)$")
ax[1].stem(freq, amp)
ax[1].set_xlabel("$f$ in Hz")
ax[1].set_ylabel("$|X(f)|^2$")
# Use the explicit figure handle rather than the implicit pyplot "current figure".
fig.suptitle(f"Sample rate: {f_s} Hz, Signal period: {T} s")