Source code for spikelab.spikedata.stat_utils

"""Statistical utilities for SpikeLab.

Provides reusable statistical functions (regression, confidence intervals,
pairwise group comparisons, paired tests, omnibus tests) that can be used
independently of plotting.
"""

import numpy as np


[docs] def linear_regression(x, y, ci_level=0.95): """Compute ordinary least-squares linear regression with optional confidence interval. Parameters: x (np.ndarray): 1-D array of predictor values. y (np.ndarray): 1-D array of response values (same length as *x*). ci_level (float): Confidence level for the interval (default 0.95). Returns: result (dict): Dictionary with keys: - ``slope`` (float): Fitted slope. - ``intercept`` (float): Fitted intercept. - ``r_squared`` (float): Coefficient of determination. - ``x_fit`` (np.ndarray): Sorted x values for plotting the fit line. - ``y_fit`` (np.ndarray): Predicted y values along *x_fit*. - ``ci_lower`` (np.ndarray): Lower confidence bound along *x_fit*. - ``ci_upper`` (np.ndarray): Upper confidence bound along *x_fit*. Notes: - Uses pure numpy (no scipy/sklearn dependency). - NaN pairs are dropped automatically. """ x = np.asarray(x, dtype=float).ravel() y = np.asarray(y, dtype=float).ravel() if len(x) != len(y): raise ValueError("x and y must have the same length.") # Drop NaN pairs valid = ~(np.isnan(x) | np.isnan(y)) x = x[valid] y = y[valid] n = len(x) if n < 3: raise ValueError("Need at least 3 non-NaN data points for regression.") # OLS via normal equations x_mean = np.mean(x) y_mean = np.mean(y) ss_xx = np.sum((x - x_mean) ** 2) if ss_xx == 0: raise ValueError("All x values are identical; regression is undefined.") ss_xy = np.sum((x - x_mean) * (y - y_mean)) slope = ss_xy / ss_xx intercept = y_mean - slope * x_mean # Predictions and R² y_pred = slope * x + intercept ss_res = np.sum((y - y_pred) ** 2) ss_tot = np.sum((y - y_mean) ** 2) r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0 # Confidence interval (t-distribution approximation via normal for large n) # For small n we use a simple approximation; scipy is not required. se = np.sqrt(ss_res / (n - 2)) if n > 2 else 0.0 # Approximate t critical value using normal quantile (good for n > 10, # conservative for smaller n) alpha = 1.0 - ci_level # Rational approximation of the normal quantile (Abramowitz & Stegun 26.2.23) p = 1.0 - alpha / 2.0 t_val = _approx_normal_quantile(p) x_fit = np.sort(x) y_fit = slope * x_fit + intercept se_fit = se * np.sqrt(1.0 / n + (x_fit - x_mean) ** 2 / ss_xx) ci_lower = y_fit - t_val * se_fit ci_upper = y_fit + t_val * se_fit return { "slope": slope, "intercept": intercept, "r_squared": r_squared, "x_fit": x_fit, "y_fit": y_fit, "ci_lower": ci_lower, "ci_upper": ci_upper, }
def _approx_normal_quantile(p): """Approximate the standard normal quantile for *p* in (0.5, 1). Uses the rational approximation from Abramowitz & Stegun (26.2.23). Accurate to ~4.5e-4 for typical confidence levels. """ if p <= 0.5: raise ValueError("p must be > 0.5") t = np.sqrt(-2.0 * np.log(1.0 - p)) c0, c1, c2 = 2.515517, 0.802853, 0.010328 d1, d2, d3 = 1.432788, 0.189269, 0.001308 return t - (c0 + c1 * t + c2 * t**2) / (1.0 + d1 * t + d2 * t**2 + d3 * t**3)
[docs] def pairwise_tests( groups, test="welch_t", correction="bonferroni", alpha=0.05, labels=None, ): """Run pairwise statistical tests across groups with multiple-comparison correction. Parameters: groups (dict[str, np.ndarray] or list[np.ndarray]): Per-group data arrays. Dict keys are used as labels; for list input supply ``labels`` separately. test (str): Statistical test to use. ``"welch_t"`` (default) for Welch's unequal-variance t-test, ``"student_t"`` for Student's equal-variance t-test, ``"mann_whitney"`` for the Mann-Whitney U test. All require ``scipy``. correction (str or None): Multiple-comparison correction. ``"bonferroni"`` (default) or ``None`` for uncorrected p-values. alpha (float): Significance threshold applied after correction. Default 0.05. labels (list[str] or None): Group labels. Required for list input; ignored for dict input (keys are used). Returns: result (dict): Dictionary with keys: - ``pval_matrix`` (np.ndarray): (K, K) corrected p-values. Diagonal entries are NaN. - ``sig_matrix`` (np.ndarray): (K, K) boolean — True where corrected p < ``alpha``. - ``n_comparisons`` (int): Number of pairwise comparisons. - ``labels`` (list[str]): Ordered group labels. Notes: - Requires ``scipy`` (optional dependency). Raises ``ImportError`` with installation instructions if not available. """ try: from scipy import stats as sp_stats except ImportError as e: raise ImportError( "pairwise_tests requires 'scipy'. Install with: pip install scipy" ) from e # --- Normalise input -------------------------------------------------- if isinstance(groups, dict): ordered_labels = list(groups.keys()) data = [np.asarray(groups[k]).ravel() for k in ordered_labels] else: data = [np.asarray(a).ravel() for a in groups] if labels is not None: ordered_labels = list(labels) else: ordered_labels = [str(i) for i in range(len(data))] # Strip NaNs data = [d[~np.isnan(d)] for d in data] K = len(data) n_comparisons = K * (K - 1) // 2 pval_matrix = np.full((K, K), np.nan) # --- Select test function --------------------------------------------- if test == "welch_t": def _test(a, b): _, p = sp_stats.ttest_ind(a, b, equal_var=False) return p elif test == "student_t": def _test(a, b): _, p = sp_stats.ttest_ind(a, b, equal_var=True) return p elif test == "mann_whitney": def _test(a, b): _, p = sp_stats.mannwhitneyu(a, b, alternative="two-sided") return p else: raise ValueError( f"Unknown test '{test}'. Use 'welch_t', 'student_t', or 'mann_whitney'." ) # --- Run pairwise tests ----------------------------------------------- for i in range(K): for j in range(i + 1, K): p = _test(data[i], data[j]) if correction == "bonferroni": p = min(p * n_comparisons, 1.0) elif correction is not None: raise ValueError( f"Unknown correction '{correction}'. " "Use 'bonferroni' or None." ) pval_matrix[i, j] = p pval_matrix[j, i] = p sig_matrix = pval_matrix < alpha return { "pval_matrix": pval_matrix, "sig_matrix": sig_matrix, "n_comparisons": n_comparisons, "labels": ordered_labels, }
[docs] def paired_test( a, b, test="wilcoxon", alternative="two-sided", ): """Run a paired statistical test on two matched samples. Parameters: a (array-like): First sample (1-D). b (array-like): Second sample (1-D, same length as *a*). test (str): ``"wilcoxon"`` (default) for the Wilcoxon signed-rank test, or ``"paired_t"`` for a paired Student's t-test. alternative (str): ``"two-sided"`` (default), ``"less"``, or ``"greater"``. Passed directly to the underlying scipy test. Returns: result (dict): Dictionary with keys: - ``statistic`` (float): Test statistic (W for Wilcoxon, t for paired t). - ``p_value`` (float): p-value for the test. - ``n`` (int): Number of valid (non-NaN, non-zero-difference) pairs used. Notes: - NaN pairs (where either *a* or *b* is NaN) are dropped automatically. - Requires ``scipy`` (optional dependency). """ try: from scipy import stats as sp_stats except ImportError as e: raise ImportError( "paired_test requires 'scipy'. Install with: pip install scipy" ) from e a = np.asarray(a, dtype=float).ravel() b = np.asarray(b, dtype=float).ravel() if len(a) != len(b): raise ValueError( f"a and b must have the same length, got {len(a)} and {len(b)}." ) valid = ~(np.isnan(a) | np.isnan(b)) a = a[valid] b = b[valid] if len(a) < 1: raise ValueError("No valid (non-NaN) pairs to test.") if test == "wilcoxon": stat, p = sp_stats.wilcoxon(a, b, alternative=alternative) elif test == "paired_t": stat, p = sp_stats.ttest_rel(a, b, alternative=alternative) else: raise ValueError(f"Unknown test '{test}'. Use 'wilcoxon' or 'paired_t'.") return {"statistic": float(stat), "p_value": float(p), "n": len(a)}
[docs] def mixed_effects_compare( values, fixed_effects, random_effect, *, formula=None, response_label="value", alpha=0.05, ): """Fit a linear mixed-effects model with one random intercept. Wraps ``statsmodels.regression.mixed_linear_model.MixedLM``. Typical use: compare a per-observation metric across treatments while accounting for repeated measurements within a recording, unit, or subject. Parameters: values (array-like): 1-D response variable, one entry per observation. fixed_effects (dict[str, array-like]): Mapping of fixed-effect names to per-observation arrays. All arrays must have the same length as *values*. Categorical effects can be passed as string/object arrays — statsmodels will dummy-code them via the Patsy formula. random_effect (array-like): Categorical group labels (one per observation) used as the random intercept group. formula (str or None): Patsy formula for the fixed-effects part. When None (default), an additive formula ``"<response_label> ~ k1 + k2 + ..."`` is built from the keys of *fixed_effects*. Provide an explicit formula for interactions (e.g. ``"value ~ treatment * latency"``). response_label (str): Name used for the response column in the constructed DataFrame and in the auto-built formula (default ``"value"``). alpha (float): Significance threshold applied to coefficient p-values when building the ``significant`` flags. Default 0.05. Returns: result (dict): Dictionary with keys: - ``params`` (dict[str, float]): Estimated coefficients keyed by Patsy term name. - ``pvalues`` (dict[str, float]): Two-sided p-values per term. - ``conf_int`` (dict[str, tuple[float, float]]): Confidence intervals per term (using *alpha*). - ``significant`` (dict[str, bool]): True where the term's p-value is below *alpha*. - ``random_effect_variance`` (float): Estimated variance of the random intercept. - ``n_obs`` (int): Number of observations used. - ``n_groups`` (int): Number of random-effect levels. - ``converged`` (bool): Whether the optimizer converged. - ``summary`` (str): Full statsmodels summary text. - ``model``: The fitted ``MixedLMResults`` object. Notes: - Requires both ``pandas`` and ``statsmodels`` (optional dependencies). Raises ``ImportError`` with installation instructions when either is missing. - NaN observations (in *values*, any fixed effect, or the random effect column) are dropped before fitting. - Tries the lbfgs, powell, and bfgs optimizers in turn; near- singular designs commonly break lbfgs' gradient path. """ try: import pandas as pd except ImportError as e: raise ImportError( "mixed_effects_compare requires 'pandas'. " "Install with: pip install pandas" ) from e try: import statsmodels.formula.api as smf except ImportError as e: raise ImportError( "mixed_effects_compare requires 'statsmodels'. " "Install with: pip install statsmodels" ) from e values = np.asarray(values).ravel() n = len(values) if n == 0: raise ValueError("values must not be empty.") if not isinstance(fixed_effects, dict) or len(fixed_effects) == 0: raise ValueError( "fixed_effects must be a non-empty dict mapping name -> array." ) columns = {response_label: values} for name, arr in fixed_effects.items(): arr = np.asarray(arr).ravel() if len(arr) != n: raise ValueError( f"fixed_effects[{name!r}] has length {len(arr)} but values has {n}." ) columns[name] = arr random_arr = np.asarray(random_effect).ravel() if len(random_arr) != n: raise ValueError( f"random_effect has length {len(random_arr)} but values has {n}." ) columns["_random_group"] = random_arr df = pd.DataFrame(columns) df = df.dropna() if len(df) < 2: raise ValueError( "Need at least 2 valid (non-NaN) observations after dropping NaNs." ) if df["_random_group"].nunique() < 2: raise ValueError( "Need at least 2 distinct random-effect levels to fit a mixed model." ) if formula is None: formula = f"{response_label} ~ " + " + ".join(fixed_effects.keys()) model = smf.mixedlm(formula, data=df, groups=df["_random_group"]) fit = None last_err = None for method in ("lbfgs", "powell", "bfgs"): try: fit = model.fit(method=method) break except (np.linalg.LinAlgError, ValueError) as e: last_err = e continue if fit is None: raise RuntimeError( "MixedLM failed to converge with any optimizer " "(tried lbfgs, powell, bfgs). The design may be near-singular — " "check that fixed effects are not collinear and that random-effect " f"groups have enough observations. Last error: {last_err!r}" ) params = {k: float(v) for k, v in fit.params.items() if k != "Group Var"} pvalues = {k: float(v) for k, v in fit.pvalues.items() if k != "Group Var"} ci_df = fit.conf_int(alpha=alpha) conf_int = { k: (float(row[0]), float(row[1])) for k, row in ci_df.iterrows() if k != "Group Var" } significant = {k: bool(p < alpha) for k, p in pvalues.items()} re_variance = ( float(fit.cov_re.iloc[0, 0]) if hasattr(fit.cov_re, "iloc") else float(np.asarray(fit.cov_re).ravel()[0]) ) return { "params": params, "pvalues": pvalues, "conf_int": conf_int, "significant": significant, "random_effect_variance": re_variance, "n_obs": int(len(df)), "n_groups": int(df["_random_group"].nunique()), "converged": bool(fit.converged), "summary": str(fit.summary()), "model": fit, }
[docs] def omnibus_test( groups, test="anova", posthoc="tukey", labels=None, ): """Run an omnibus test across groups with optional post-hoc comparisons. Parameters: groups (dict[str, array-like] or list[array-like]): Per-group data. Dict keys are used as labels; for list input supply *labels* separately. test (str): ``"anova"`` (default) for one-way ANOVA (``scipy.stats.f_oneway``), or ``"kruskal"`` for the Kruskal-Wallis H test. posthoc (str or None): Post-hoc test to run when the omnibus test is significant. ``"tukey"`` (default) for Tukey HSD, ``"none"``/``None`` to skip post-hoc comparisons. labels (list[str] or None): Group labels for list input. Ignored for dict input (keys are used). Returns: result (dict): Dictionary with keys: - ``statistic`` (float): F-statistic (ANOVA) or H-statistic (Kruskal-Wallis). - ``p_value`` (float): Omnibus p-value. - ``n_groups`` (int): Number of groups. - ``group_ns`` (list[int]): Sample sizes per group. - ``labels`` (list[str]): Ordered group labels. - ``posthoc`` (list[dict] or None): Post-hoc comparison results when *posthoc* is not None. Each dict contains: ``"group_a"``, ``"group_b"``, ``"p_value"``, ``"significant"`` (at alpha=0.05). Notes: - NaN values are stripped from each group before testing. - Requires ``scipy`` (optional dependency). """ try: from scipy import stats as sp_stats except ImportError as e: raise ImportError( "omnibus_test requires 'scipy'. Install with: pip install scipy" ) from e if isinstance(groups, dict): ordered_labels = list(groups.keys()) data = [np.asarray(groups[k], dtype=float).ravel() for k in ordered_labels] else: data = [np.asarray(a, dtype=float).ravel() for a in groups] ordered_labels = ( list(labels) if labels is not None else [str(i) for i in range(len(data))] ) data = [d[~np.isnan(d)] for d in data] K = len(data) if K < 2: raise ValueError("Need at least 2 groups for an omnibus test.") if test == "anova": stat, p = sp_stats.f_oneway(*data) elif test == "kruskal": stat, p = sp_stats.kruskal(*data) else: raise ValueError(f"Unknown test '{test}'. Use 'anova' or 'kruskal'.") result = { "statistic": float(stat), "p_value": float(p), "n_groups": K, "group_ns": [len(d) for d in data], "labels": ordered_labels, "posthoc": None, } if posthoc is not None and posthoc != "none": if posthoc == "tukey": tukey = sp_stats.tukey_hsd(*data) posthoc_results = [] for i in range(K): for j in range(i + 1, K): pv = float(tukey.pvalue[i, j]) posthoc_results.append( { "group_a": ordered_labels[i], "group_b": ordered_labels[j], "p_value": pv, "significant": pv < 0.05, } ) result["posthoc"] = posthoc_results else: raise ValueError(f"Unknown posthoc '{posthoc}'. Use 'tukey' or None.") return result