import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import chi2, norm
def plot_chi2_normal_approx(ns, sample_size=10000, *, random_state=None):
    """Plot standardized chi-squared samples against the standard normal PDF.

    For each degrees-of-freedom value ``n`` in *ns*, draw *sample_size*
    samples from ``chi2(df=n)``, standardize them with the chi-squared mean
    ``n`` and standard deviation ``sqrt(2 * n)``, and overlay their density
    histogram with the N(0, 1) PDF in a dedicated subplot.

    Parameters
    ----------
    ns : iterable of int
        Degrees of freedom to plot, one subplot per value. Must be non-empty;
        each value is presumably expected to be positive (``n == 0`` would
        divide by zero in the standardization).
    sample_size : int, optional
        Number of chi-squared samples drawn per subplot.
    random_state : None, int or numpy.random.Generator, optional
        Seed or generator used for sampling. ``None`` (default) keeps the
        original behavior of letting ``scipy.stats.chi2.rvs`` use NumPy's
        global random state.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn (returned after ``plt.show()``).

    Raises
    ------
    ValueError
        If *ns* is empty.
    """
    ns = list(ns)
    if not ns:
        raise ValueError("ns must contain at least one degrees-of-freedom value")

    x = np.linspace(-4, 4, 1000)
    true_pdf = norm.pdf(x)

    # One generator shared across panels so a single seed reproduces the whole
    # figure; None is passed straight through to keep scipy's default
    # (global-state) sampling for callers who seed np.random themselves.
    rng = None if random_state is None else np.random.default_rng(random_state)

    # squeeze=False guarantees a 2-D Axes array even for a single subplot; the
    # default squeeze=True returns a bare Axes when len(ns) == 1, which cannot
    # be indexed.
    fig, axes = plt.subplots(
        1, len(ns), figsize=(15, 4), sharey=True, squeeze=False
    )
    for ax, n in zip(axes[0], ns):
        chi_samples = chi2.rvs(df=n, size=sample_size, random_state=rng)
        # Standardize: a chi2(n) variable has mean n and variance 2n.
        normalized = (chi_samples - n) / np.sqrt(2 * n)
        ax.hist(normalized, bins=50, density=True, alpha=0.6, color='skyblue', label='Normalized $\\chi^2$')
        ax.plot(x, true_pdf, 'r--', label='Standard Normal PDF')
        ax.set_title(f'n = {n}')
        ax.legend()
        ax.grid(True)

    fig.suptitle(r'As $n \to \infty$, $(X - n)/\sqrt{2n} \sim \mathcal{N}(0,1)$', fontsize=16)
    # Reserve the top 5% of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()
    return fig
# Run the demo only when executed as a script, so importing this module does
# not open a plot window as a side effect.
if __name__ == "__main__":
    plot_chi2_normal_approx([5, 30, 100])