import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.optimize import minimize

# =========================
# Probit log-likelihood
# =========================
def loglik_probit(beta, X, y):
    """Return the negative probit log-likelihood evaluated at ``beta``.

    Despite the name, the value is the *negated* log-likelihood, so the
    function can be handed directly to a minimizer.

    Parameters
    ----------
    beta : array_like
        Coefficient vector; must be conformable with ``X @ beta``.
    X : array_like
        Design matrix, one row per observation.
    y : array_like
        Binary responses coded as 0/1, one per row of ``X``.

    Returns
    -------
    float
        ``-sum(y * log Phi(z) + (1 - y) * log(1 - Phi(z)))`` with ``z = X @ beta``.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    z = X @ np.asarray(beta)

    # Work in log space: log(1 - Phi(z)) == log Phi(-z) by symmetry of the
    # normal distribution.  This stays accurate in the tails, where
    # log(cdf + eps) would saturate and bias the objective.
    log_p1 = norm.logcdf(z)
    log_p0 = norm.logcdf(-z)
    return -np.sum(y * log_p1 + (1 - y) * log_p0)

# =========================
# Data generation (MISSPECIFIED)
# =========================
def simulate_data_t(n, beta_true, df=3, seed=None):
    """Simulate a binary-response sample driven by Student-t latent errors.

    The latent index is ``X @ beta_true + e`` with ``e`` drawn from a
    Student-t distribution, and the observed response is 1 when the latent
    index is strictly positive.

    Parameters
    ----------
    n : int
        Number of observations.
    beta_true : array_like
        Coefficients for the intercept and the single regressor.
    df : float, optional
        Degrees of freedom of the Student-t errors.
    seed : optional
        Forwarded to ``np.random.default_rng``.

    Returns
    -------
    X : ndarray, shape (n, 2)
        Column of ones followed by a standard-normal regressor.
    y : ndarray of int, shape (n,)
        0/1 responses.
    """
    generator = np.random.default_rng(seed)

    # The regressor is drawn before the noise; keeping this order keeps
    # results reproducible for a given seed.
    regressor = generator.normal(size=n)
    noise = generator.standard_t(df, size=n)

    intercept = np.ones(n)
    X = np.stack((intercept, regressor), axis=1)

    latent = X @ beta_true + noise
    y = np.greater(latent, 0).astype(int)
    return X, y

# =========================
# Monte Carlo
# =========================
def monte_carlo_misspec(
    n_list=(100, 500, 2000),
    reps=200,
    beta_true=np.array([1.0, 1.0]),
    df=3,
    random_seed=123
):
    """Run a Monte Carlo study of probit estimates on t-error data.

    For every sample size in ``n_list`` the function simulates ``reps``
    data sets with ``simulate_data_t``, fits a probit by minimizing
    ``loglik_probit`` with BFGS from a zero start, and records the estimates.

    Parameters
    ----------
    n_list : iterable of int
        Sample sizes to simulate.
    reps : int
        Number of replications per sample size.
    beta_true : array_like
        Coefficients used to generate the data.
    df : float
        Degrees of freedom of the Student-t errors.
    random_seed : optional
        Seeds the master generator from which per-replication seeds are drawn.

    Returns
    -------
    pandas.DataFrame
        One row per replication with columns ``n``, ``beta0_hat``,
        ``beta1_hat``, ``error_true`` (Euclidean distance between the
        estimate and ``beta_true``) and ``converged`` (the optimizer's
        ``success`` flag).  The columns are present even when no
        replication is run.
    """
    columns = ["n", "beta0_hat", "beta1_hat", "error_true", "converged"]
    beta_true = np.asarray(beta_true, dtype=float)
    rng = np.random.default_rng(random_seed)
    rows = []

    for n in n_list:
        for _ in range(reps):
            # Each replication gets its own seed drawn from the master stream.
            seed = rng.integers(0, 10**9)

            X, y = simulate_data_t(n, beta_true, df=df, seed=seed)

            res = minimize(loglik_probit,
                           x0=np.zeros(X.shape[1]),
                           args=(X, y),
                           method="BFGS")

            beta_hat = res.x

            rows.append({
                "n": n,
                "beta0_hat": beta_hat[0],
                "beta1_hat": beta_hat[1],
                "error_true": np.linalg.norm(beta_hat - beta_true),
                # Keep the optimizer status so failed fits can be filtered
                # or counted instead of silently entering the summary.
                "converged": bool(res.success),
            })

    # Explicit columns keep downstream groupby("n") working on empty results.
    return pd.DataFrame(rows, columns=columns)

# =========================
# Run experiment
# =========================
# One row per replication, produced with the default experiment settings.
df = monte_carlo_misspec()

# =========================
# Summary
# =========================
# Per-sample-size mean and standard deviation of each coefficient estimate,
# plus the average distance of the estimate from the true coefficients.
aggregations = {
    "mean_b0": pd.NamedAgg(column="beta0_hat", aggfunc="mean"),
    "mean_b1": pd.NamedAgg(column="beta1_hat", aggfunc="mean"),
    "std_b0": pd.NamedAgg(column="beta0_hat", aggfunc="std"),
    "std_b1": pd.NamedAgg(column="beta1_hat", aggfunc="std"),
    "mean_error_true": pd.NamedAgg(column="error_true", aggfunc="mean"),
}
summary = df.groupby("n").agg(**aggregations)

print("\nMisspecification results (t-errors, df=3):\n")
print(summary)

# =========================
# Plots
# =========================

# 1. Distribution of estimates
# Overlay one semi-transparent histogram of the slope estimates per sample
# size; sort=False keeps the groups in their order of first appearance.
plt.figure(figsize=(7,5))
for n, sub in df.groupby("n", sort=False):
    plt.hist(sub["beta1_hat"], bins=30, alpha=0.4, label=f"n={n}")

plt.title("Probit estimates under misspecification (t-errors)")
plt.xlabel("beta1_hat")
plt.legend()
plt.show()

# 2. Error vs n
# Average, per sample size, the per-replication distance between the
# estimate and the true coefficients.  The replication-level column is
# "error_true"; "mean_error_true" exists only in the summary table, so
# selecting it from df would raise a KeyError.
plt.figure(figsize=(7,5))
err = df.groupby("n")["error_true"].mean()
plt.plot(err.index, err.values, marker="o")
plt.title("Error vs n (misspecified model)")
plt.xlabel("n")
plt.ylabel("distance from true beta")
plt.grid()
plt.show()