Source code for millipede.normal

import math
from types import SimpleNamespace

import torch
from torch import einsum, matmul, sigmoid
from torch.distributions import Beta, Categorical, Gamma
from torch.linalg import norm
from torch.linalg import solve_triangular as trisolve

from .sampler import MCMCSampler
from .util import (
    arange_complement,
    get_loo_inverses,
    leave_one_out,
    leave_one_out_off_diagonal,
    safe_cholesky,
    sample_active_subset,
)


class NormalLikelihoodSampler(MCMCSampler):
    r"""
    MCMC sampler for Bayesian variable selection for a linear model with a Normal likelihood.
    The likelihood variance is controlled by an Inverse Gamma prior.
    This class supports continuous-valued responses.

    The details of the available models in `NormalLikelihoodSampler` are as follows.
    The covariates :math:`X` and responses :math:`Y` are defined as follows:

    .. math::

        X \in \mathbb{R}^{N \times P} \qquad \qquad Y \in \mathbb{R}^{N}

    The inclusion of each covariate is governed by a Bernoulli random variable :math:`\gamma_p`.
    In particular :math:`\gamma_p = 0` corresponds to exclusion and :math:`\gamma_p = 1`
    corresponds to inclusion. The prior probability of inclusion is governed by :math:`h` or,
    alternatively, :math:`S`:

    .. math::

        h \in [0, 1] \qquad \rm{with} \qquad S \equiv hP

    Alternatively, if :math:`h` is not known a priori we can put a prior on :math:`h`:

    .. math::

        h \sim {\rm Beta}(\alpha, \beta) \qquad \rm{with} \qquad \alpha > 0 \;\;\;\; \beta > 0

    Putting this together, the model specification for an isotropic prior (with an intercept
    :math:`\beta_0` included) is as follows:

    .. math::

        &\gamma_p \sim \rm{Bernoulli}(h) \qquad \rm{for} \qquad p=1,2,...,P \\
        &\sigma^2 \sim \rm{InverseGamma}(\nu_0 / 2, \nu_0 \lambda_0 / 2) \\
        &\beta_0 \sim \rm{Normal}(0, \sigma^2 \tau_{\rm intercept}^{-1}) \\
        &\beta_\gamma \sim \rm{Normal}(0, \sigma^2 \tau^{-1} \mathbb{1}_\gamma) \\
        &Y_n \sim \rm{Normal}(\beta_0 + X_{n, \gamma} \cdot \beta_\gamma, \sigma^2)

    Note that the dimension of :math:`\beta_\gamma` depends on the number of covariates included
    in a particular model (i.e. on the number of non-zero entries in :math:`\gamma`).

    For a gprior the prior over the coefficients is instead specified as follows:

    .. math::

        \beta_{\gamma} \sim \rm{Normal}(0, c \sigma^2 (X_\gamma^{\rm{T}} X_\gamma)^{-1})
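
    For concreteness, the isotropic model corresponds to the following generative process
    (a minimal sketch; the sizes and hyperparameters are arbitrary, and :math:`\nu_0` and
    :math:`\lambda_0` are taken to be positive so that the InverseGamma prior is proper)::

        import torch

        N, P, h = 50, 100, 0.05
        tau, nu0, lambda0 = 0.01, 2.0, 1.0
        gamma = torch.rand(P) < h                         # Bernoulli(h) inclusion indicators
        # sigma^2 ~ InverseGamma(nu0 / 2, nu0 * lambda0 / 2)
        sigma_sq = torch.distributions.Gamma(0.5 * nu0, 0.5 * nu0 * lambda0).sample().reciprocal()
        beta = torch.zeros(P)
        # included coefficients are Normal(0, sigma^2 / tau); excluded coefficients are zero
        beta[gamma] = (sigma_sq / tau).sqrt() * torch.randn(int(gamma.sum().item()))
        X = torch.randn(N, P)
        Y = X @ beta + sigma_sq.sqrt() * torch.randn(N)   # intercept beta_0 omitted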

    Usage of this class is only recommended for advanced users. For most users it should
    suffice to use :class:`NormalLikelihoodVariableSelector`.

    :param tensor X: An N x P `torch.Tensor` of covariates.
    :param tensor Y: An N-dimensional `torch.Tensor` of continuous responses.
    :param tensor X_assumed: An N x P' `torch.Tensor` of covariates that are always assumed to be
        part of the model. Defaults to `None`.
    :param tensor sigma_scale_factor: An N-dimensional `torch.Tensor` of positive scale factors
        that are used to scale the standard deviation of the Normal likelihood for each datapoint.
        For example, specifying 2.0 for a particular datapoint results in
        :math:`\sigma \rightarrow 2 \times \sigma`. Defaults to `None`.
    :param S: Controls the expected number of covariates to include in the model a priori.
        Defaults to 5.0. To specify covariate-level prior inclusion probabilities provide a
        P-dimensional `torch.Tensor` of the form `(h_1, ..., h_P)`. If a tuple of positive floats
        `(alpha, beta)` is provided, the a priori inclusion probability is a latent variable
        governed by the corresponding Beta prior so that the sparsity level is inferred from the
        data. Note that for a given choice of `alpha` and `beta` the expected number of covariates
        to include in the model a priori is given by :math:`\frac{\alpha}{\alpha + \beta} \times P`.
        Also note that the mean number of covariates in the posterior can vary significantly from
        prior expectations, since the posterior is in effect a compromise between the prior and
        the observed data.
    :param str prior: One of the two supported priors for the coefficients: 'isotropic' or
        'gprior'. Defaults to 'isotropic'.
    :param bool include_intercept: Whether to include an intercept term. If included, the
        intercept term is included in all models so that the corresponding coefficient does not
        have a PIP. Defaults to True.
    :param float tau: Controls the precision of the coefficients in the isotropic prior.
        Defaults to 0.01.
    :param float tau_intercept: Controls the precision of the intercept in the isotropic prior.
        Defaults to 1.0e-4.
    :param float c: Controls the precision of the coefficients in the gprior. Defaults to 100.0.
    :param float nu0: Controls the prior over the precision in the Normal likelihood.
        Defaults to 0.0.
    :param float lambda0: Controls the prior over the precision in the Normal likelihood.
        Defaults to 0.0.
    :param float explore: This hyperparameter controls how greedy the MCMC algorithm is.
        Defaults to 5.0.
    :param bool precompute_XX: Whether the matrix X^t @ X should be pre-computed. Defaults to
        False. Note that setting this to True may result in out-of-memory errors for sufficiently
        large covariate matrices.
    :param bool compute_betas: Whether to compute sampled posterior coefficients :math:`\beta`
        at each MCMC iteration. Defaults to False.
    :param bool verbose_constructor: Whether the class constructor should print some information
        to stdout upon initialization. Defaults to True.
    :param float xi_target: This hyperparameter controls how often :math:`h` MCMC updates are
        made if :math:`h` is a latent variable. Defaults to 0.2.
    :param int subset_size: If `subset_size` is not None, `subset_size` controls the amount of
        computational resources to use in Subset wTGS. Otherwise, if `subset_size` is None,
        vanilla wTGS is used. This argument is intended to be used for datasets with a very large
        number of covariates (e.g. tens of thousands or more). A typical value might be ~5-10% of
        the total number of covariates; smaller values result in more MCMC iterations per second
        but may lead to high variance PIP estimates. Defaults to None.
    :param int anchor_size: If `subset_size` is not None, `anchor_size` controls how greedy
        Subset wTGS is. If `anchor_size` is None it defaults to half of `subset_size`.
        For expert users only. Defaults to None.
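
    Example (a schematic sketch of driving the sampler by hand on simulated data; note that
    the bookkeeping attributes ``t`` and ``T_burnin`` used by :meth:`mcmc_move` are assumed
    to be managed by the :class:`MCMCSampler` base class, so they are set manually here)::

        import torch

        X = torch.randn(100, 10).double()
        Y = X[:, 0] - 0.5 * X[:, 1] + 0.1 * torch.randn(100).double()
        sampler = NormalLikelihoodSampler(X, Y, S=1.0, prior="isotropic")

        sampler.t, sampler.T_burnin = 0, 500        # normally handled by the MCMC driver
        sample = sampler.initialize_sample(seed=0)
        pip, total_weight = torch.zeros(10).double(), 0.0
        for t in range(2500):
            sample = sampler.mcmc_move(sample)
            if t >= 500:                            # discard burn-in iterations
                pip = pip + sample.weight * sample.pip
                total_weight = total_weight + sample.weight
        pip = pip / total_weight                    # importance-weighted PIP estimates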
""" def __init__(self, X, Y, X_assumed=None, sigma_scale_factor=None, S=5.0, prior="isotropic", include_intercept=True, tau=0.01, tau_intercept=1.0e-4, c=100.0, nu0=0.0, lambda0=0.0, explore=5, precompute_XX=False, compute_betas=False, verbose_constructor=True, xi_target=0.2, subset_size=None, anchor_size=None): assert prior in ['isotropic', 'gprior'] self.N, self.P = X.shape assert (self.N,) == Y.shape assert X.dtype == Y.dtype assert X.device == Y.device if subset_size is not None and (subset_size <= 1 or subset_size >= self.P): raise ValueError("If subset_size is not None must be strictly between 1 and P, the number of covariates.") self.subset_size = subset_size if anchor_size is not None: if subset_size is None: raise ValueError("The anchor_size argument should only be used if subset_size is not None.") if anchor_size < 1 or anchor_size >= subset_size: raise ValueError("anchor_size should be strictly between 0 and subset_size.") if X_assumed is not None: assert X.dtype == X_assumed.dtype assert X.device == X_assumed.device if X.size(0) != X_assumed.size(0): raise ValueError("X and X_assumed must have the same number of rows.") if sigma_scale_factor is not None: assert sigma_scale_factor.dtype == X.dtype assert sigma_scale_factor.device == X.device if sigma_scale_factor.shape != (self.N,): raise ValueError("sigma_scale_factor must be a N-dimensional tensor.") if sigma_scale_factor.min().item() <= 0.0: raise ValueError("All entries in sigma_scale_factor must be positive.") if prior != "isotropic": raise ValueError("sigma_scale_factor can only be used in conjuction with an isotropic prior.") self.device = Y.device self.dtype = Y.dtype self.prior = prior self.X = X self.Y = Y self.c = c if prior == 'gprior' else 0.0 self.tau = tau if prior == 'isotropic' else 0.0 if prior == 'isotropic': self.tau_intercept = tau_intercept if include_intercept else tau else: self.tau_intercept = 0.0 if X_assumed is not None: assert X_assumed.size(-1) > 0 self.X = torch.cat([self.X, X_assumed], dim=-1) if include_intercept: self.X = torch.cat([self.X, X.new_ones(X.size(0), 1)], dim=-1) S = S if not isinstance(S, int) else float(S) if isinstance(S, float): if S >= self.P or S <= 0: raise ValueError("S must satisfy 0 < S < P or must be a tuple or tensor.") elif isinstance(S, tuple): if len(S) != 2 or not isinstance(S[0], float) or not isinstance(S[1], float) or S[0] <= 0.0 or S[1] <= 0.0: raise ValueError("If S is a tuple it must be a tuple of two positive floats (alpha, beta).") elif isinstance(S, torch.Tensor): if S.shape != (self.P,) or (S >= 1.0).any().item() or (S <= 0.0).any().item(): raise ValueError("If S is a tensor it must be P-dimensional and all elements must be strictly" + " contained in (0, 1).") else: raise ValueError("S must be a float, tuple or tensor.") if prior == 'gprior' and self.c <= 0.0: raise ValueError("c must satisfy c > 0.0") if prior == 'isotropic' and self.tau <= 0.0: raise ValueError("tau must satisfy tau > 0.0") if explore <= 0.0: raise ValueError("explore must satisfy explore > 0.0") if nu0 < 0.0: raise ValueError("nu0 must satisfy nu0 >= 0.0") if lambda0 < 0.0: raise ValueError("lambda0 must satisfy lambda0 >= 0.0") if xi_target <= 0.0 or xi_target >= 1.0: raise ValueError("xi_target must be in the interval (0, 1).") if sigma_scale_factor is not None: self.X = self.X.clone() / sigma_scale_factor.unsqueeze(-1) Y_scaled = Y / sigma_scale_factor self.YY = Y_scaled.pow(2.0).sum() + nu0 * lambda0 self.Z = einsum("np,n->p", self.X, Y_scaled) else: self.YY = Y.pow(2.0).sum() + nu0 * 

        if isinstance(S, float):
            self.h = S / self.P
            self.xi = torch.tensor([0.0], device=self.device)
            self.log_h_ratio = math.log(self.h) - math.log(1.0 - self.h)
        elif isinstance(S, tuple):
            self.h_alpha, self.h_beta = S
            self.h = self.h_alpha / (self.h_alpha + self.h_beta)
            self.xi = torch.tensor([5.0], device=self.device)
            self.xi_target = xi_target
            self.log_h_ratio = math.log(self.h) - math.log(1.0 - self.h)
        else:
            self.h = S
            self.xi = torch.tensor([0.0], device=self.device)
            self.log_h_ratio = S.log() - torch.log1p(-S)

        if prior == "gprior":
            self.c_one_c = self.c / (1.0 + self.c)
            self.c_one_c_sqrt = math.sqrt(self.c_one_c)
            self.log_one_c_sqrt = 0.5 * math.log(1.0 + self.c)

        if self.subset_size is not None:
            self.anchor_size = subset_size // 2 if anchor_size is None else anchor_size
            self.pi = X.new_ones(self.P) * self.h if isinstance(S, (float, tuple)) else self.h
            self.total_weight = 0.0
            self.comb_factor = (self.subset_size - self.anchor_size) / (self.P - self.anchor_size)
        else:
            self.comb_factor = 1.0

        self.explore = explore / self.P
        self.N_nu0 = self.N + nu0
        self.compute_betas = compute_betas
        self.include_intercept = include_intercept

        if include_intercept and X_assumed is None:
            self.assumed_covariates = torch.tensor([self.P], device=self.device, dtype=torch.int64)
        elif include_intercept and X_assumed is not None:
            self.assumed_covariates = torch.arange(self.P, self.P + X_assumed.size(-1) + 1,
                                                   device=self.device, dtype=torch.int64)
        elif not include_intercept and X_assumed is not None:
            self.assumed_covariates = torch.arange(self.P, self.P + X_assumed.size(-1),
                                                   device=self.device, dtype=torch.int64)
        else:
            self.assumed_covariates = None

        if precompute_XX:
            self.XX = self.X.t() @ self.X
            self.XX_diag = self.XX.diagonal()
        else:
            self.XX = None

        self.Pa = 0 if self.assumed_covariates is None else self.assumed_covariates.size(-1)
        self.epsilon = 1.0e3 * torch.finfo(Y.dtype).tiny

        if verbose_constructor:
            s2 = " = ({}, {}, {:.1f}, {:.3f}, {})" if not isinstance(S, tuple) \
                else " = ({}, {}, ({:.1f}, {:.1f}), {:.3f}, {})"
            if isinstance(S, float):
                S = (S,)
            elif isinstance(S, torch.Tensor):
                S = (S.min().item(), S.max().item())
            if self.prior == 'isotropic':
                s1 = "Initialized NormalLikelihoodSampler with isotropic prior and (N, P, S, tau, subset_size)"
                print((s1 + s2).format(self.N, self.P, *S, self.tau, self.subset_size))
            else:
                s1 = "Initialized NormalLikelihoodSampler with gprior and (N, P, S, c, subset_size)"
                print((s1 + s2).format(self.N, self.P, *S, self.c, self.subset_size))

    def initialize_sample(self, seed=None):
        if seed is not None:
            torch.manual_seed(seed)

        sample = SimpleNamespace(gamma=torch.zeros(self.P, device=self.device).bool(),
                                 _active=torch.tensor([], device=self.device, dtype=torch.int64),
                                 _log_h_ratio=self.log_h_ratio)
        if self.Pa > 0:
            sample._activeb = self.assumed_covariates

        if self.subset_size is not None:
            # initialize the anchor set with the covariates most correlated with the
            # (centered) responses
            Z_cent = einsum("np,n->p", self.X[:, :self.P], self.Y - self.Y.mean())
            self._update_anchor(Z_cent.abs().argsort()[-self.anchor_size:])
            sample._idx = torch.randint(self.P, (), device=self.device)
            sample._active_subset = sample_active_subset(self.P, self.subset_size, self.anchor_subset,
                                                         self.anchor_subset_set, self.anchor_complement,
                                                         sample._idx)

        if hasattr(self, "h_alpha"):
            sample.h_alpha = torch.tensor(self.h_alpha, device=self.device)
            sample.h_beta = torch.tensor(self.h_beta, device=self.device)

        sample = self._compute_probs(sample)
        return sample

    def _update_anchor(self, anchor):
        self.anchor_subset = anchor
        self.anchor_subset_set = set(anchor.data.cpu().numpy().tolist())
        self.anchor_complement = arange_complement(self.P, anchor)
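
    # The quantity computed by _compute_add_prob is the Rao-Blackwellized log-odds that each
    # covariate is included, conditioned on the rest of gamma. For each inactive covariate k
    # the required Schur complement
    #     G_k_inv = x_k^T x_k + tau - x_k^T X_active (X_active^T X_active + tau I)^{-1} X_active^T x_k
    # is obtained via triangular solves against the Cholesky factor of the active-set Gram
    # matrix, so no per-candidate matrix factorization is required.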
    def _compute_add_prob(self, sample, return_log_odds=False):
        active = sample._active
        activeb = sample._activeb if self.Pa > 0 else sample._active

        if self.subset_size is not None:
            inactive = torch.zeros(self.P, device=self.device, dtype=torch.bool)
            inactive[sample._active_subset] = ~(sample.gamma[sample._active_subset])
            inactive = torch.nonzero(inactive).squeeze(-1)
        else:
            inactive = torch.nonzero(~sample.gamma).squeeze(-1)

        num_active = active.size(-1)
        assert num_active < self.P, "The MCMC sampler has been driven into a regime where " +\
            "all covariates have been selected. Are you sure you have chosen a reasonable prior? " +\
            "Are you sure there is signal in your data?"

        X_k = self.X[:, inactive]
        Z_k = self.Z[inactive]
        if self.XX is None:
            XX_k = norm(X_k, dim=0).pow(2.0)
        else:
            XX_k = self.XX_diag[inactive]

        if self.Pa > 0 or num_active > 0:
            X_activeb = self.X[:, activeb]
            Z_active = self.Z[activeb]

            if self.XX is not None:
                XX_active = self.XX[activeb][:, activeb]
            else:
                XX_active = X_activeb.t() @ X_activeb
            if self.prior == 'isotropic':
                XX_active.diagonal(dim1=-2, dim2=-1).add_(self.tau)
                if self.include_intercept:
                    XX_active[-1, -1].add_(self.tau_intercept - self.tau)

            L_active = safe_cholesky(XX_active)
            Zt_active = trisolve(L_active, Z_active.unsqueeze(-1), upper=False).squeeze(-1)
            Xt_active = trisolve(L_active, X_activeb.t(), upper=False).t()
            XtZt_active = einsum("np,p->n", Xt_active, Zt_active)

            if self.XX is None:
                G_k_inv = XX_k + self.tau - norm(einsum("ni,nk->ik", Xt_active, X_k), dim=0).pow(2.0)
            else:
                normsq = trisolve(L_active, self.XX[activeb][:, inactive], upper=False)
                G_k_inv = XX_k + self.tau - norm(normsq, dim=0).pow(2.0)

            W_k_sq = (einsum("np,n->p", X_k, XtZt_active) - Z_k).pow(2.0) / (G_k_inv + self.epsilon)
            Zt_active_sq = Zt_active.pow(2.0).sum()

            if self.prior == 'isotropic':
                log_det_inactive = -0.5 * G_k_inv.log() + 0.5 * math.log(self.tau)
        else:
            W_k_sq = Z_k.pow(2.0) / (XX_k + self.tau + self.epsilon)
            Zt_active_sq = 0.0
            if self.prior == 'isotropic':
                log_det_inactive = -0.5 * torch.log1p(XX_k / self.tau)

        if self.compute_betas and (num_active > 0 or self.Pa > 0):
            beta_active = trisolve(L_active.t(), Zt_active.unsqueeze(-1), upper=True).squeeze(-1)
            sample.beta = self.Y.new_zeros(self.P + self.Pa)
            epsilon = torch.randn(activeb.size(-1), 1, device=self.device, dtype=self.dtype)
            if self.prior == 'gprior':
                sigma_beta = 0.5 * (self.YY - self.c_one_c * Zt_active_sq)
                sample.sigma = Gamma(0.5 * self.N_nu0, sigma_beta).sample().sqrt().reciprocal()
                sample.beta[activeb] = self.c_one_c * beta_active
                sample.beta[activeb] += self.c_one_c_sqrt * sample.sigma * \
                    trisolve(L_active, epsilon, upper=False).squeeze(-1)
            else:
                sigma_beta = 0.5 * (self.YY - Zt_active_sq)
                sample.sigma = Gamma(0.5 * self.N_nu0, sigma_beta).sample().sqrt().reciprocal()
                sample.beta[activeb] = beta_active
                sample.beta[activeb] += sample.sigma * trisolve(L_active, epsilon, upper=False).squeeze(-1)
        elif self.compute_betas and num_active == 0:
            sample.beta = self.Y.new_zeros(self.P + self.Pa)
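
        # For covariates that are currently active, the flip log-odds require the statistics
        # of each leave-one-out (LOO) model in which a single active covariate is removed.
        # Rather than refactorizing each LOO Gram matrix from scratch, all LOO inverses are
        # extracted at once from the full inverse F (see get_loo_inverses in millipede.util).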
        if num_active > 1:
            active_loo = leave_one_out(active)  # shape (I, I - 1)
            if self.Pa > 0:
                assumed_covariates = self.assumed_covariates.expand(num_active, -1)
                active_loob = torch.cat([active_loo, assumed_covariates], dim=-1)
            else:
                active_loob = active_loo
            Z_active_loo = self.Z[active_loob]
            F = torch.cholesky_inverse(L_active, upper=False)
            F_loo = get_loo_inverses(F)
            if self.Pa > 0:
                F_loo = F_loo[:-self.Pa]
            Zt_active_loo = matmul(F_loo, Z_active_loo.unsqueeze(-1)).squeeze(-1)
            Zt_active_loo_sq = einsum("ij,ij->i", Zt_active_loo, Z_active_loo)
            if self.prior == 'isotropic':
                X_I_X_k = leave_one_out_off_diagonal(XX_active).unsqueeze(-1)
                X_I_X_k = X_I_X_k if self.Pa == 0 else X_I_X_k[:-self.Pa]
                F_X_I_X_k = matmul(F_loo, X_I_X_k).squeeze(-1)
                XXFXX = einsum("ij,ij->i", X_I_X_k.squeeze(-1), F_X_I_X_k)
                XX_active_diag = XX_active.diag() if self.Pa == 0 else XX_active.diag()[:-self.Pa]
                G_k_inv = XX_active_diag - XXFXX
                log_det_active = -0.5 * G_k_inv.log() + 0.5 * math.log(self.tau)
        elif num_active == 1:
            if self.Pa == 0:
                Zt_active_loo_sq = 0.0
                if self.prior == 'isotropic':
                    log_det_active = -0.5 * torch.log1p(norm(self.X[:, active], dim=0).pow(2.0) / self.tau)
            else:
                if self.XX is None:
                    XX_assumed = self.X[:, self.assumed_covariates]
                    XX_assumed = XX_assumed.t() @ XX_assumed
                else:
                    XX_assumed = self.XX[self.assumed_covariates][:, self.assumed_covariates]
                if self.prior == "isotropic":
                    XX_assumed.diagonal(dim1=-2, dim2=-1).add_(self.tau)
                    if self.include_intercept:
                        XX_assumed[-1, -1].add_(self.tau_intercept - self.tau)
                L_assumed = safe_cholesky(XX_assumed)
                Zt_active_loo = trisolve(L_assumed, self.Z[self.assumed_covariates].unsqueeze(-1),
                                         upper=False).squeeze(-1)
                Zt_active_loo_sq = norm(Zt_active_loo, dim=0).pow(2.0)
                if self.prior == "isotropic":
                    log_det_active = L_assumed.diagonal().log().sum() - L_active.diagonal().log().sum()
                    log_det_active += 0.5 * math.log(self.tau)
        elif num_active == 0:
            Zt_active_loo_sq = 0.0
            log_det_active = torch.tensor(0.0, device=self.device, dtype=self.dtype)

        log_h_ratio_active = sample._log_h_ratio[active] if isinstance(self.h, torch.Tensor) \
            else sample._log_h_ratio
        log_h_ratio_inactive = sample._log_h_ratio[inactive] if isinstance(self.h, torch.Tensor) \
            else sample._log_h_ratio

        if self.prior == 'gprior':
            log_S_ratio = -torch.log1p(-self.c_one_c * W_k_sq / (self.YY - self.c_one_c * Zt_active_sq))
            log_odds_inactive = log_h_ratio_inactive - self.log_one_c_sqrt + 0.5 * self.N_nu0 * log_S_ratio
            log_S_ratio = torch.log(self.YY - self.c_one_c * Zt_active_loo_sq) -\
                torch.log(self.YY - self.c_one_c * Zt_active_sq)
            log_odds_active = log_h_ratio_active - self.log_one_c_sqrt + 0.5 * self.N_nu0 * log_S_ratio
        elif self.prior == 'isotropic':
            log_S_ratio = -torch.log1p(- W_k_sq / (self.YY - Zt_active_sq))
            log_odds_inactive = log_h_ratio_inactive + log_det_inactive + 0.5 * self.N_nu0 * log_S_ratio
            log_S_ratio = (self.YY - Zt_active_loo_sq).log() - (self.YY - Zt_active_sq).log()
            log_odds_active = log_h_ratio_active + log_det_active + 0.5 * self.N_nu0 * log_S_ratio

        log_odds = self.Y.new_full((self.P,), -torch.inf)
        log_odds[inactive] = log_odds_inactive
        log_odds[active] = log_odds_active

        return log_odds
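
    # _compute_probs converts the Rao-Blackwellized inclusion probabilities into the
    # proposal distribution over coordinates used by (weighted) Tempered Gibbs Sampling:
    # coordinates whose current inclusion state has low conditional probability are
    # proposed more often, the explore term keeps every coordinate proposable, and
    # sample.weight records the importance weight that corrects downstream PIP estimates
    # for this non-uniform coordinate selection.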
    def _compute_probs(self, sample):
        sample._add_prob = sigmoid(self._compute_add_prob(sample))

        gamma = sample.gamma.type_as(sample._add_prob)
        prob_gamma_i = gamma * sample._add_prob + (1.0 - gamma) * (1.0 - sample._add_prob)
        _i_prob = 0.5 * (sample._add_prob + self.explore) / (prob_gamma_i + self.epsilon)

        if self.subset_size is not None:
            _i_prob[self.anchor_subset] *= self.comb_factor
            i_prob = torch.zeros_like(_i_prob)
            i_prob[sample._active_subset] = _i_prob[sample._active_subset]
            sample.pip = sample.gamma.type_as(i_prob)
            sample.pip[sample._active_subset] = sample._add_prob[sample._active_subset]
        else:
            i_prob = _i_prob
            sample.pip = sample._add_prob

        if hasattr(self, 'h_alpha') and self.t <= self.T_burnin:  # adapt xi
            xi_comb = self.xi * self.comb_factor
            self.xi += (self.xi_target - xi_comb / (xi_comb + i_prob.sum())) / math.sqrt(self.t + 1)
            self.xi.clamp_(min=0.01)

        sample._i_prob = torch.cat([self.xi * self.comb_factor, i_prob])
        return sample

    def mcmc_move(self, sample):
        self.t += 1

        sample._idx = Categorical(probs=sample._i_prob).sample() - 1
        if sample._idx.item() >= 0:
            sample.gamma[sample._idx] = ~sample.gamma[sample._idx]
            sample._active = torch.nonzero(sample.gamma).squeeze(-1)
            if self.Pa > 0:
                sample._activeb = torch.cat([sample._active, self.assumed_covariates])
        else:
            sample = self.sample_alpha_beta(sample)

        if self.subset_size is not None:
            sample._active_subset = sample_active_subset(self.P, self.subset_size, self.anchor_subset,
                                                         self.anchor_subset_set, self.anchor_complement,
                                                         sample._idx)

        sample = self._compute_probs(sample)
        sample.weight = sample._i_prob.mean().reciprocal()

        if self.subset_size is not None and self.t <= self.T_burnin:
            self.pi = sample.weight * sample.pip + self.total_weight * self.pi
            self.total_weight += sample.weight
            self.pi /= self.total_weight
            if (self.t > 99 and self.t % 100 == 0) or self.t == self.T_burnin:
                self._update_anchor(self.pi.argsort()[-self.anchor_size:])

        return sample

    def sample_alpha_beta(self, sample):
        num_active = sample._active.size(-1)
        num_inactive = self.P - num_active
        sample.h_alpha = torch.tensor(self.h_alpha + num_active, device=self.device)
        sample.h_beta = torch.tensor(self.h_beta + num_inactive, device=self.device)
        h = Beta(sample.h_alpha, sample.h_beta).sample().item()
        sample._log_h_ratio = math.log(h) - math.log(1.0 - h)
        return sample