import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.optimize import minimize
import matplotlib.pyplot as plt

# =========================
# Probit negative log-likelihood (objective to minimize)
# =========================
def loglik(beta, X, y):
    """Return the negative probit log-likelihood at ``beta``.

    Despite the name, the value is the *negated* log-likelihood, so the
    function can be passed directly to a minimizer.

    Parameters
    ----------
    beta : numpy.ndarray
        Coefficient vector; the linear index is ``X @ beta``.
    X : numpy.ndarray
        Design matrix with one row per observation.
    y : numpy.ndarray
        Outcomes, one per row of ``X`` (1 weights ``log Phi(z)``, 0 weights
        ``log Phi(-z)``).

    Returns
    -------
    float
        ``-sum(y * log Phi(z) + (1 - y) * log Phi(-z))`` with ``z = X @ beta``.
    """
    z = X @ beta
    # Work in log space: log(1 - Phi(z)) equals log Phi(-z) by symmetry of the
    # normal distribution. Clamping Phi(z) with a small epsilon instead would
    # cap each term near -log(eps) and flatten the objective in the tails,
    # which hides the true penalty of badly misclassified points.
    log_p1 = norm.logcdf(z)
    log_p0 = norm.logcdf(-z)
    return -np.sum(y * log_p1 + (1 - y) * log_p0)

# =========================
# Data generation
# =========================
def simulate_data(n, beta_true, seed=None):
    """Draw one sample from a probit model with an intercept and one regressor.

    Parameters
    ----------
    n : int
        Number of observations.
    beta_true : array_like
        Coefficients applied to the two-column design (intercept, regressor).
    seed : optional
        Passed to ``numpy.random.default_rng``.

    Returns
    -------
    tuple
        ``(X, y)`` where ``X`` is the ``(n, 2)`` design matrix and ``y`` the
        0/1 integer outcomes, equal to 1 where the latent index is positive.
    """
    generator = np.random.default_rng(seed)

    # The regressor is drawn before the latent noise; keeping this order keeps
    # the random stream (and therefore the sample) unchanged for a given seed.
    regressor = generator.normal(size=n)
    noise = generator.normal(size=n)

    design = np.stack((np.ones(n), regressor), axis=1)
    latent = design @ beta_true + noise

    outcome = np.greater(latent, 0).astype(int)
    return design, outcome

# =========================
# Monte Carlo
# =========================
def monte_carlo(n_list=(100, 500, 2000), reps=200, beta_true=(1.0, 1.0),
                seed=1234):
    """Run a Monte Carlo study of the probit estimator across sample sizes.

    For every sample size in ``n_list`` the function simulates ``reps``
    datasets, fits the coefficients by minimizing ``loglik`` with BFGS from a
    zero starting point, and records the estimates.

    Parameters
    ----------
    n_list : iterable of int
        Sample sizes to study.
    reps : int
        Number of replications per sample size.
    beta_true : array_like of length 2
        True (intercept, slope) coefficients used to simulate the data.
    seed : optional
        Seed of the master generator that draws one seed per replication.
        The default reproduces the original fixed stream.

    Returns
    -------
    pandas.DataFrame
        One row per replication with columns ``n``, ``beta0_hat``,
        ``beta1_hat`` and ``error`` (Euclidean distance between the estimate
        and ``beta_true``). The columns are present even when no replication
        was run.

    Raises
    ------
    ValueError
        If ``beta_true`` does not have exactly two entries.
    """
    beta_true = np.asarray(beta_true, dtype=float)
    # simulate_data builds a two-column design (intercept + one regressor) and
    # the result columns are beta0_hat/beta1_hat, so only two coefficients fit.
    if beta_true.shape != (2,):
        raise ValueError("beta_true must contain exactly two coefficients")

    columns = ["n", "beta0_hat", "beta1_hat", "error"]
    results = []

    rng = np.random.default_rng(seed)

    for n in n_list:
        for _ in range(reps):
            # Each replication gets its own seed drawn from the master stream.
            rep_seed = rng.integers(0, 10**9)
            X, y = simulate_data(n, beta_true, rep_seed)

            # NOTE: the optimizer's success flag is not inspected; every
            # returned estimate is recorded.
            res = minimize(loglik,
                           x0=np.zeros_like(beta_true),
                           args=(X, y),
                           method="BFGS")

            beta_hat = res.x

            results.append({
                "n": n,
                "beta0_hat": beta_hat[0],
                "beta1_hat": beta_hat[1],
                "error": np.linalg.norm(beta_hat - beta_true),
            })

    # Explicit columns keep the frame usable (e.g. for groupby("n")) even when
    # n_list is empty or reps is zero.
    return pd.DataFrame(results, columns=columns)

# =========================
# Run experiment
# =========================
# Run the experiment with the default sample sizes and replication count.
df = monte_carlo()

# =========================
# Summary
# =========================
# For each sample size: mean and standard deviation of both coefficient
# estimates, plus the average of the recorded estimation error.
summary = df.groupby("n").agg(
    mean_b0=pd.NamedAgg(column="beta0_hat", aggfunc="mean"),
    mean_b1=pd.NamedAgg(column="beta1_hat", aggfunc="mean"),
    std_b0=pd.NamedAgg(column="beta0_hat", aggfunc="std"),
    std_b1=pd.NamedAgg(column="beta1_hat", aggfunc="std"),
    mean_error=pd.NamedAgg(column="error", aggfunc="mean"),
)

print(summary)

# =========================
# Plots
# =========================

# 1. Consistency: overlay the beta1 estimates of every sample size in one
#    figure so the distributions can be compared directly.
plt.figure()
for n in df["n"].unique():
    sub = df.loc[df["n"] == n]
    plt.hist(sub["beta1_hat"], label=f"n={n}", alpha=0.4, bins=30)
plt.title("Distribution of beta1 estimates")
plt.legend()
plt.show()

# 2. Error vs n: average recorded estimation error for each sample size.
plt.figure()
err = df.groupby("n")["error"].mean()
plt.plot(err.index, err.to_numpy(), marker="o")
plt.xlabel("n")
plt.ylabel("Error")
plt.title("Mean estimation error vs n")
plt.grid()
plt.show()