import math
from types import SimpleNamespace

import torch
from torch import einsum, matmul, sigmoid
from torch.distributions import Beta, Categorical, Gamma
from torch.linalg import norm
from torch.linalg import solve_triangular as trisolve

from .sampler import MCMCSampler
from .util import (
    arange_complement,
    get_loo_inverses,
    leave_one_out,
    leave_one_out_off_diagonal,
    safe_cholesky,
    sample_active_subset,
)

class NormalLikelihoodSampler(MCMCSampler):
    r"""
    MCMC sampler for Bayesian variable selection for a linear model with a Normal likelihood.
    The likelihood variance is controlled by an Inverse Gamma prior.
    This class supports continuous-valued responses.

    The details of the available models in `NormalLikelihoodSampler` are as follows.
    The covariates :math:`X` and responses :math:`Y` are defined as follows:

    .. math::

        X \in \mathbb{R}^{N \times P} \qquad \qquad Y \in \mathbb{R}^{N}

    The inclusion of each covariate is governed by a Bernoulli random variable :math:`\gamma_p`.
    In particular :math:`\gamma_p = 0` corresponds to exclusion and :math:`\gamma_p = 1` corresponds to inclusion.
    The prior probability of inclusion is governed by :math:`h` or alternatively :math:`S`:

    .. math::

        h \in [0, 1] \qquad \rm{with} \qquad S \equiv hP

    Alternatively, if :math:`h` is not known a priori we can put a prior on :math:`h`:

    .. math::

        h \sim {\rm Beta}(\alpha, \beta) \qquad \rm{with} \qquad \alpha > 0 \;\;\;\; \beta > 0

    Putting this together, the model specification for an isotropic prior (with an intercept
    :math:`\beta_0` included) is as follows:

    .. math::

        &\gamma_p \sim \rm{Bernoulli}(h) \qquad \rm{for} \qquad p=1,2,...,P

        &\sigma^2 \sim \rm{InverseGamma}(\nu_0 / 2, \nu_0 \lambda_0 / 2)

        &\beta_0 \sim \rm{Normal}(0, \sigma^2\tau_\rm{intercept}^{-1})

        &\beta_\gamma \sim \rm{Normal}(0, \sigma^2 \tau^{-1} \mathbb{1}_\gamma)

        &Y_n \sim \rm{Normal}(\beta_0 + X_{n, \gamma} \cdot \beta_\gamma, \sigma^2)

    Note that the dimension of :math:`\beta_\gamma` depends on the number of covariates
    included in a particular model (i.e. on the number of non-zero entries in :math:`\gamma`).
    For a gprior the prior over the coefficients is instead specified as follows:

    .. math::

        \beta_{\gamma} \sim \rm{Normal}(0, c \sigma^2 (X_\gamma^{\rm{T}} X_\gamma)^{-1})

    Usage of this class is only recommended for advanced users. For most users it should suffice
    to use :class:`NormalLikelihoodVariableSelector`.

    :param tensor X: A N x P `torch.Tensor` of covariates.
    :param tensor Y: A N-dimensional `torch.Tensor` of continuous responses.
    :param tensor X_assumed: A N x P' `torch.Tensor` of covariates that are always assumed to be part of the model.
        Defaults to `None`.
    :param tensor sigma_scale_factor: A N-dimensional `torch.Tensor` of positive scale factors that are used to
        scale the standard deviation of the Normal likelihood for each datapoint. For example, specifying 2.0 for
        a particular datapoint results in :math:`\sigma \rightarrow 2 \times \sigma`. Can only be used with
        the isotropic prior. Defaults to `None`.
    :param S: Controls the expected number of covariates to include in the model a priori. Defaults to 5.0.
        To specify covariate-level prior inclusion probabilities provide a P-dimensional `torch.Tensor` of
        the form `(h_1, ..., h_P)`. If a tuple of positive floats `(alpha, beta)` is provided, the a priori
        inclusion probability is a latent variable governed by the corresponding Beta prior so that the sparsity
        level is inferred from the data. Note that for a given choice of `alpha` and `beta` the expected number
        of covariates to include in the model a priori is given by :math:`\frac{\alpha}{\alpha + \beta} \times P`.
        Also note that the mean number of covariates in the posterior can vary significantly from prior
        expectations, since the posterior is in effect a compromise between the prior and the observed data.
    :param str prior: One of the two supported priors for the coefficients: 'isotropic' or 'gprior'.
        Defaults to 'isotropic'.
    :param bool include_intercept: Whether to include an intercept term. If included the intercept term is
        included in all models so that the corresponding coefficient does not have a PIP.
    :param float tau: Controls the precision of the coefficients in the isotropic prior. Defaults to 0.01.
    :param float tau_intercept: Controls the precision of the intercept in the isotropic prior.
        Defaults to 1.0e-4.
    :param float c: Controls the precision of the coefficients in the gprior. Defaults to 100.0.
    :param float nu0: Controls the prior over the precision in the Normal likelihood. Defaults to 0.0.
    :param float lambda0: Controls the prior over the precision in the Normal likelihood. Defaults to 0.0.
    :param float explore: This hyperparameter controls how greedy the MCMC algorithm is. Defaults to 5.0.
    :param bool precompute_XX: Whether the matrix X^t @ X should be pre-computed. Defaults to False. Note
        that setting this to True may result in out-of-memory errors for sufficiently large covariate matrices.
    :param bool compute_betas: Whether each sample should additionally carry a draw of the coefficients
        (`sample.beta`) and, whenever at least one covariate (selected or assumed) is in the model, of the
        likelihood standard deviation (`sample.sigma`). Defaults to False.
    :param bool verbose_constructor: Whether the class constructor should print some information to
        stdout upon initialization.
    :param float xi_target: This hyperparameter controls how often :math:`h` MCMC updates are made if :math:`h`
        is a latent variable. Defaults to 0.2.
    :param int subset_size: If `subset_size` is not None `subset_size` controls the amount of computational
        resources to use in Subset wTGS. Otherwise if `subset_size` is None vanilla wTGS is used. This argument
        is intended to be used for datasets with a very large number of covariates (e.g. tens of thousands
        or more). A typical value might be ~5-10% of the total number of covariates; smaller values result in
        more MCMC iterations per second but may lead to high variance PIP estimates. Defaults to None.
    :param int anchor_size: If `subset_size` is not None `anchor_size` controls how greedy Subset wTGS is.
        If `anchor_size` is None it defaults to half of `subset_size`. For expert users only. Defaults to None.
    """
    def __init__(self, X, Y, X_assumed=None, sigma_scale_factor=None, S=5.0,
                 prior="isotropic", include_intercept=True,
                 tau=0.01, tau_intercept=1.0e-4, c=100.0,
                 nu0=0.0, lambda0=0.0,
                 explore=5, precompute_XX=False,
                 compute_betas=False, verbose_constructor=True,
                 xi_target=0.2, subset_size=None, anchor_size=None):
        assert prior in ['isotropic', 'gprior']

        self.N, self.P = X.shape
        assert (self.N,) == Y.shape

        assert X.dtype == Y.dtype
        assert X.device == Y.device

        if subset_size is not None and (subset_size <= 1 or subset_size >= self.P):
            raise ValueError("If subset_size is not None must be strictly between 1 and P, the number of covariates.")
        self.subset_size = subset_size

        if anchor_size is not None:
            if subset_size is None:
                raise ValueError("The anchor_size argument should only be used if subset_size is not None.")
            if anchor_size < 1 or anchor_size >= subset_size:
                raise ValueError("anchor_size should be strictly between 0 and subset_size.")

        if X_assumed is not None:
            assert X.dtype == X_assumed.dtype
            assert X.device == X_assumed.device
            if X.size(0) != X_assumed.size(0):
                raise ValueError("X and X_assumed must have the same number of rows.")

        if sigma_scale_factor is not None:
            assert sigma_scale_factor.dtype == X.dtype
            assert sigma_scale_factor.device == X.device
            if sigma_scale_factor.shape != (self.N,):
                raise ValueError("sigma_scale_factor must be a N-dimensional tensor.")
            if sigma_scale_factor.min().item() <= 0.0:
                raise ValueError("All entries in sigma_scale_factor must be positive.")
            if prior != "isotropic":
                raise ValueError("sigma_scale_factor can only be used in conjuction with an isotropic prior.")

        self.device = Y.device
        self.dtype = Y.dtype

        self.prior = prior
        self.X = X
        self.Y = Y
        # Hyperparameters that do not belong to the chosen prior are zeroed out.
        self.c = c if prior == 'gprior' else 0.0
        self.tau = tau if prior == 'isotropic' else 0.0

        if prior == 'isotropic':
            self.tau_intercept = tau_intercept if include_intercept else tau
        else:
            self.tau_intercept = 0.0

        # Column layout of self.X: [X | X_assumed (optional) | column of ones (optional intercept)].
        if X_assumed is not None:
            assert X_assumed.size(-1) > 0
            self.X = torch.cat([self.X, X_assumed], dim=-1)
        if include_intercept:
            self.X = torch.cat([self.X, X.new_ones(X.size(0), 1)], dim=-1)

        S = S if not isinstance(S, int) else float(S)
        if isinstance(S, float):
            if S >= self.P or S <= 0:
                raise ValueError("S must satisfy 0 < S < P or must be a tuple or tensor.")
        elif isinstance(S, tuple):
            if len(S) != 2 or not isinstance(S[0], float) or not isinstance(S[1], float) or S[0] <= 0.0 or S[1] <= 0.0:
                raise ValueError("If S is a tuple it must be a tuple of two positive floats (alpha, beta).")
        elif isinstance(S, torch.Tensor):
            if S.shape != (self.P,) or (S >= 1.0).any().item() or (S <= 0.0).any().item():
                raise ValueError("If S is a tensor it must be P-dimensional and all elements must be strictly" +
                                 " contained in (0, 1).")
        else:
            raise ValueError("S must be a float, tuple or tensor.")

        if prior == 'gprior' and self.c <= 0.0:
            raise ValueError("c must satisfy c > 0.0")
        if prior == 'isotropic' and self.tau <= 0.0:
            raise ValueError("tau must satisfy tau > 0.0")
        if explore <= 0.0:
            raise ValueError("explore must satisfy explore > 0.0")
        if nu0 < 0.0:
            raise ValueError("nu0 must satisfy nu0 >= 0.0")
        if lambda0 < 0.0:
            raise ValueError("lambda0 must satisfy lambda0 >= 0.0")
        if xi_target <= 0.0 or xi_target >= 1.0:
            raise ValueError("xi_target must be in the interval (0, 1).")

        # Sufficient statistics: YY = |Y|^2 + nu0 * lambda0 and Z = X^T Y, computed on
        # row-rescaled data when per-datapoint scale factors are supplied.
        if sigma_scale_factor is not None:
            self.X = self.X.clone() / sigma_scale_factor.unsqueeze(-1)
            Y_scaled = Y / sigma_scale_factor
            self.YY = Y_scaled.pow(2.0).sum() + nu0 * lambda0
            self.Z = einsum("np,n->p", self.X, Y_scaled)
        else:
            self.YY = Y.pow(2.0).sum() + nu0 * lambda0
            self.Z = einsum("np,n->p", self.X, Y)

        # xi is the unnormalised weight of the "update h" move; it is only non-zero
        # when h is latent, i.e. when S is an (alpha, beta) tuple.
        if isinstance(S, float):
            self.h = S / self.P
            self.xi = torch.tensor([0.0], device=self.device)
            self.log_h_ratio = math.log(self.h) - math.log(1.0 - self.h)
        elif isinstance(S, tuple):
            self.h_alpha, self.h_beta = S
            self.h = self.h_alpha / (self.h_alpha + self.h_beta)
            self.xi = torch.tensor([5.0], device=self.device)
            self.xi_target = xi_target
            self.log_h_ratio = math.log(self.h) - math.log(1.0 - self.h)
        else:
            self.h = S
            self.xi = torch.tensor([0.0], device=self.device)
            self.log_h_ratio = S.log() - torch.log1p(-S)

        if prior == "gprior":
            self.c_one_c = self.c / (1.0 + self.c)
            self.c_one_c_sqrt = math.sqrt(self.c_one_c)
            self.log_one_c_sqrt = 0.5 * math.log(1.0 + self.c)

        if self.subset_size is not None:
            self.anchor_size = subset_size // 2 if anchor_size is None else anchor_size
            # Running weighted average of PIPs; used in mcmc_move to refresh the anchor set.
            self.pi = X.new_ones(self.P) * self.h if isinstance(S, (float, tuple)) else self.h
            self.total_weight = 0.0
            self.comb_factor = (self.subset_size - self.anchor_size) / (self.P - self.anchor_size)
        else:
            self.comb_factor = 1.0

        self.explore = explore / self.P
        self.N_nu0 = self.N + nu0

        self.compute_betas = compute_betas
        self.include_intercept = include_intercept

        # Column indices (into self.X) of covariates that are always in the model;
        # the intercept column, if present, is always the last one.
        if include_intercept and X_assumed is None:
            self.assumed_covariates = torch.tensor([self.P], device=self.device, dtype=torch.int64)
        elif include_intercept and X_assumed is not None:
            self.assumed_covariates = torch.arange(self.P, self.P + X_assumed.size(-1) + 1,
                                                   device=self.device, dtype=torch.int64)
        elif not include_intercept and X_assumed is not None:
            self.assumed_covariates = torch.arange(self.P, self.P + X_assumed.size(-1),
                                                   device=self.device, dtype=torch.int64)
        else:
            self.assumed_covariates = None

        if precompute_XX:
            self.XX = self.X.t() @ self.X
            self.XX_diag = self.XX.diagonal()
        else:
            self.XX = None

        self.Pa = 0 if self.assumed_covariates is None else self.assumed_covariates.size(-1)

        # Small positive constant added to denominators to avoid division by zero.
        self.epsilon = 1.0e3 * torch.finfo(Y.dtype).tiny

        if verbose_constructor:
            # Normalise S to a tuple first and only then pick the format template, so that
            # the number of S placeholders always matches the number of S values
            # (a tensor-valued S is summarised by its (min, max)).
            if isinstance(S, float):
                S_summary = (S,)
            elif isinstance(S, torch.Tensor):
                S_summary = (S.min().item(), S.max().item())
            else:
                S_summary = S
            if len(S_summary) == 1:
                s2 = " = ({}, {}, {:.1f}, {:.3f}, {})"
            else:
                s2 = " = ({}, {}, ({:.1f}, {:.1f}), {:.3f}, {})"

            if self.prior == 'isotropic':
                s1 = "Initialized NormalLikelihoodSampler with isotropic prior and (N, P, S, tau, subset_size)"
                print((s1 + s2).format(self.N, self.P, *S_summary, self.tau, self.subset_size))
            else:
                s1 = "Initialized NormalLikelihoodSampler with gprior and (N, P, S, c, subset_size)"
                print((s1 + s2).format(self.N, self.P, *S_summary, self.c, self.subset_size))

    def initialize_sample(self, seed=None):
        """
        Build the initial MCMC state, in which no covariate is selected.

        :param int seed: If not None, passed to `torch.manual_seed` before anything is sampled.
        :returns: A `SimpleNamespace` holding `gamma`, the bookkeeping fields used by the sampler
            and the quantities filled in by `_compute_probs`.
        """
        if seed is not None:
            torch.manual_seed(seed)

        sample = SimpleNamespace(gamma=torch.zeros(self.P, device=self.device).bool(),
                                 _active=torch.tensor([], device=self.device, dtype=torch.int64),
                                 _log_h_ratio=self.log_h_ratio)
        if self.Pa > 0:
            sample._activeb = self.assumed_covariates

        if self.subset_size is not None:
            # Initial anchor set: the covariates with the largest absolute inner product
            # with the mean-centered responses.
            Z_cent = einsum("np,n->p", self.X[:, :self.P], self.Y - self.Y.mean())
            self._update_anchor(Z_cent.abs().argsort()[-self.anchor_size:])
            sample._idx = torch.randint(self.P, (), device=self.device)
            sample._active_subset = sample_active_subset(self.P, self.subset_size,
                                                         self.anchor_subset, self.anchor_subset_set,
                                                         self.anchor_complement, sample._idx)

        if hasattr(self, "h_alpha"):
            sample.h_alpha = torch.tensor(self.h_alpha, device=self.device)
            sample.h_beta = torch.tensor(self.h_beta, device=self.device)

        sample = self._compute_probs(sample)
        return sample

    def _update_anchor(self, anchor):
        """
        Store a new anchor set together with the derived structures handed to `sample_active_subset`.

        :param tensor anchor: 1-dimensional integer tensor of covariate indices.
        """
        self.anchor_subset = anchor
        # Plain Python set of the same indices (presumably for fast membership tests
        # inside `sample_active_subset`).
        self.anchor_subset_set = set(anchor.tolist())
        self.anchor_complement = arange_complement(self.P, anchor)

    def _compute_add_prob(self, sample, return_log_odds=False):
        """
        Compute per-covariate log odds for the current state; `_compute_probs` passes the
        result through a sigmoid. If `compute_betas` is set this also stores `sample.beta`
        (and `sample.sigma` when the model is non-empty) on `sample` as a side effect.

        :param sample: Current MCMC state as produced by `initialize_sample` / `mcmc_move`.
        :param bool return_log_odds: Unused; log odds are always returned.
        :returns: A P-dimensional tensor of log odds. Entries for covariates that are neither
            active nor among the considered inactive covariates are left at `-inf`.
        """
        active = sample._active
        # "activeb" = selected covariates followed by the always-included ones.
        activeb = sample._activeb if self.Pa > 0 else sample._active

        if self.subset_size is not None:
            # Only inactive covariates inside the current active subset are considered.
            inactive = torch.zeros(self.P, device=self.device, dtype=torch.bool)
            inactive[sample._active_subset] = ~(sample.gamma[sample._active_subset])
            inactive = torch.nonzero(inactive).squeeze(-1)
        else:
            inactive = torch.nonzero(~sample.gamma).squeeze(-1)

        num_active = active.size(-1)

        assert num_active < self.P, "The MCMC sampler has been driven into a regime where " +\
            "all covariates have been selected. Are you sure you have chosen a reasonable prior? " +\
            "Are you sure there is signal in your data?"

        X_k = self.X[:, inactive]
        Z_k = self.Z[inactive]
        if self.XX is None:
            XX_k = norm(X_k, dim=0).pow(2.0)
        else:
            XX_k = self.XX_diag[inactive]

        if self.Pa > 0 or num_active > 0:
            X_activeb = self.X[:, activeb]
            Z_active = self.Z[activeb]
            if self.XX is not None:
                XX_active = self.XX[activeb][:, activeb]
            else:
                XX_active = X_activeb.t() @ X_activeb
            if self.prior == 'isotropic':
                XX_active.diagonal(dim1=-2, dim2=-1).add_(self.tau)
                if self.include_intercept:
                    # The intercept is the last column and uses its own prior precision.
                    XX_active[-1, -1].add_(self.tau_intercept - self.tau)

            L_active = safe_cholesky(XX_active)

            Zt_active = trisolve(L_active, Z_active.unsqueeze(-1), upper=False).squeeze(-1)
            Xt_active = trisolve(L_active, X_activeb.t(), upper=False).t()
            XtZt_active = einsum("np,p->n", Xt_active, Zt_active)

            if self.XX is None:
                G_k_inv = XX_k + self.tau - norm(einsum("ni,nk->ik", Xt_active, X_k), dim=0).pow(2.0)
            else:
                normsq = trisolve(L_active, self.XX[activeb][:, inactive], upper=False)
                G_k_inv = XX_k + self.tau - norm(normsq, dim=0).pow(2.0)

            W_k_sq = (einsum("np,n->p", X_k, XtZt_active) - Z_k).pow(2.0) / (G_k_inv + self.epsilon)
            Zt_active_sq = Zt_active.pow(2.0).sum()

            if self.prior == 'isotropic':
                log_det_inactive = -0.5 * G_k_inv.log() + 0.5 * math.log(self.tau)
        else:
            # Empty model: no active and no assumed covariates.
            W_k_sq = Z_k.pow(2.0) / (XX_k + self.tau + self.epsilon)
            Zt_active_sq = 0.0
            if self.prior == 'isotropic':
                log_det_inactive = -0.5 * torch.log1p(XX_k / self.tau)

        if self.compute_betas and (num_active > 0 or self.Pa > 0):
            beta_active = trisolve(L_active.t(), Zt_active.unsqueeze(-1), upper=True).squeeze(-1)
            sample.beta = self.Y.new_zeros(self.P + self.Pa)
            epsilon = torch.randn(activeb.size(-1), 1, device=self.device, dtype=self.dtype)
            # With XX_active = L L^T the coefficient covariance is proportional to
            # (L L^T)^{-1} = L^{-T} L^{-1}, so standard normal noise has to be mapped through
            # L^{-T} (solve with L^T); solving with L would yield covariance (L^T L)^{-1}.
            beta_noise = trisolve(L_active.t(), epsilon, upper=True).squeeze(-1)
            if self.prior == 'gprior':
                sigma_beta = 0.5 * (self.YY - self.c_one_c * Zt_active_sq)
                # A precision is drawn from the Gamma and converted to a standard deviation.
                sample.sigma = Gamma(0.5 * self.N_nu0, sigma_beta).sample().sqrt().reciprocal()
                sample.beta[activeb] = self.c_one_c * beta_active
                sample.beta[activeb] += self.c_one_c_sqrt * sample.sigma * beta_noise
            else:
                sigma_beta = 0.5 * (self.YY - Zt_active_sq)
                sample.sigma = Gamma(0.5 * self.N_nu0, sigma_beta).sample().sqrt().reciprocal()
                sample.beta[activeb] = beta_active
                sample.beta[activeb] += sample.sigma * beta_noise
        elif self.compute_betas and num_active == 0:
            sample.beta = self.Y.new_zeros(self.P + self.Pa)

        # Leave-one-out quantities for the currently active covariates.
        if num_active > 1:
            active_loo = leave_one_out(active)  # I  I-1

            if self.Pa > 0:
                assumed_covariates = self.assumed_covariates.expand(num_active, -1)
                active_loob = torch.cat([active_loo, assumed_covariates], dim=-1)
            else:
                active_loob = active_loo

            Z_active_loo = self.Z[active_loob]

            F = torch.cholesky_inverse(L_active, upper=False)
            F_loo = get_loo_inverses(F)
            if self.Pa > 0:
                # Assumed covariates are never removed, so drop their leave-one-out entries.
                F_loo = F_loo[:-self.Pa]

            Zt_active_loo = matmul(F_loo, Z_active_loo.unsqueeze(-1)).squeeze(-1)
            Zt_active_loo_sq = einsum("ij,ij->i", Zt_active_loo, Z_active_loo)

            if self.prior == 'isotropic':
                X_I_X_k = leave_one_out_off_diagonal(XX_active).unsqueeze(-1)
                X_I_X_k = X_I_X_k if self.Pa == 0 else X_I_X_k[:-self.Pa]
                F_X_I_X_k = matmul(F_loo, X_I_X_k).squeeze(-1)
                XXFXX = einsum("ij,ij->i", X_I_X_k.squeeze(-1), F_X_I_X_k)
                XX_active_diag = XX_active.diag() if self.Pa == 0 else XX_active.diag()[:-self.Pa]
                G_k_inv = XX_active_diag - XXFXX
                log_det_active = -0.5 * G_k_inv.log() + 0.5 * math.log(self.tau)

        elif num_active == 1:
            if self.Pa == 0:
                Zt_active_loo_sq = 0.0
                if self.prior == 'isotropic':
                    log_det_active = -0.5 * torch.log1p(norm(self.X[:, active], dim=0).pow(2.0) / self.tau)
            else:
                # Removing the single active covariate leaves only the assumed covariates.
                if self.XX is None:
                    XX_assumed = self.X[:, self.assumed_covariates]
                    XX_assumed = XX_assumed.t() @ XX_assumed
                else:
                    XX_assumed = self.XX[self.assumed_covariates][:, self.assumed_covariates]
                if self.prior == "isotropic":
                    XX_assumed.diagonal(dim1=-2, dim2=-1).add_(self.tau)
                    if self.include_intercept:
                        XX_assumed[-1, -1].add_(self.tau_intercept - self.tau)
                L_assumed = safe_cholesky(XX_assumed)
                Zt_active_loo = trisolve(L_assumed, self.Z[self.assumed_covariates].unsqueeze(-1),
                                         upper=False).squeeze(-1)
                Zt_active_loo_sq = norm(Zt_active_loo, dim=0).pow(2.0)
                if self.prior == "isotropic":
                    log_det_active = L_assumed.diagonal().log().sum() - L_active.diagonal().log().sum()
                    log_det_active += 0.5 * math.log(self.tau)

        elif num_active == 0:
            Zt_active_loo_sq = 0.0
            log_det_active = torch.tensor(0.0, device=self.device, dtype=self.dtype)

        # The prior log odds are per-covariate only when h was given as a tensor.
        log_h_ratio_active = sample._log_h_ratio[active] if isinstance(self.h, torch.Tensor) else sample._log_h_ratio
        log_h_ratio_inactive = sample._log_h_ratio[inactive] if isinstance(self.h, torch.Tensor) \
            else sample._log_h_ratio

        if self.prior == 'gprior':
            log_S_ratio = -torch.log1p(-self.c_one_c * W_k_sq / (self.YY - self.c_one_c * Zt_active_sq))
            log_odds_inactive = log_h_ratio_inactive - self.log_one_c_sqrt + 0.5 * self.N_nu0 * log_S_ratio

            log_S_ratio = torch.log(self.YY - self.c_one_c * Zt_active_loo_sq) -\
                torch.log(self.YY - self.c_one_c * Zt_active_sq)
            log_odds_active = log_h_ratio_active - self.log_one_c_sqrt + 0.5 * self.N_nu0 * log_S_ratio
        elif self.prior == 'isotropic':
            log_S_ratio = -torch.log1p(- W_k_sq / (self.YY - Zt_active_sq))
            log_odds_inactive = log_h_ratio_inactive + log_det_inactive + 0.5 * self.N_nu0 * log_S_ratio

            log_S_ratio = (self.YY - Zt_active_loo_sq).log() - (self.YY - Zt_active_sq).log()
            log_odds_active = log_h_ratio_active + log_det_active + 0.5 * self.N_nu0 * log_S_ratio

        log_odds = self.Y.new_full((self.P,), -torch.inf)
        log_odds[inactive] = log_odds_inactive
        log_odds[active] = log_odds_active

        return log_odds

    def _compute_probs(self, sample):
        """
        Refresh `sample._add_prob`, `sample.pip` and the move proposal weights `sample._i_prob`.

        Reads `self.t` and `self.T_burnin` when h is latent; these are not set in this class and
        are presumably provided by `MCMCSampler` or the code driving the sampler.

        :param sample: Current MCMC state; modified in place.
        :returns: The same `sample` object.
        """
        sample._add_prob = sigmoid(self._compute_add_prob(sample))

        gamma = sample.gamma.type_as(sample._add_prob)
        # Conditional probability of each covariate's *current* inclusion state.
        prob_gamma_i = gamma * sample._add_prob + (1.0 - gamma) * (1.0 - sample._add_prob)
        _i_prob = 0.5 * (sample._add_prob + self.explore) / (prob_gamma_i + self.epsilon)

        if self.subset_size is not None:
            _i_prob[self.anchor_subset] *= self.comb_factor
            # Covariates outside the active subset cannot be proposed in this step.
            i_prob = torch.zeros_like(_i_prob)
            i_prob[sample._active_subset] = _i_prob[sample._active_subset]
            # Outside the active subset the PIP estimate falls back to the current gamma.
            sample.pip = sample.gamma.type_as(i_prob)
            sample.pip[sample._active_subset] = sample._add_prob[sample._active_subset]
        else:
            i_prob = _i_prob
            sample.pip = sample._add_prob

        if hasattr(self, 'h_alpha') and self.t <= self.T_burnin:  # adapt xi
            xi_comb = self.xi * self.comb_factor
            self.xi += (self.xi_target - xi_comb / (xi_comb + i_prob.sum())) / math.sqrt(self.t + 1)
            self.xi.clamp_(min=0.01)

        # Entry 0 is the weight of the "update h" move; entries 1..P are the per-covariate flips.
        sample._i_prob = torch.cat([self.xi * self.comb_factor, i_prob])

        return sample

    def mcmc_move(self, sample):
        """
        Perform one MCMC step: either flip the inclusion state of one covariate or resample h.

        :param sample: Current MCMC state; modified in place.
        :returns: The updated `sample`, including its importance weight `sample.weight`.
        """
        self.t += 1

        # Shift by one so that -1 encodes the "update h" move and 0..P-1 are covariate indices.
        sample._idx = Categorical(probs=sample._i_prob).sample() - 1

        if sample._idx.item() >= 0:
            sample.gamma[sample._idx] = ~sample.gamma[sample._idx]
            sample._active = torch.nonzero(sample.gamma).squeeze(-1)
            if self.Pa > 0:
                sample._activeb = torch.cat([sample._active, self.assumed_covariates])
        else:
            sample = self.sample_alpha_beta(sample)

        if self.subset_size is not None:
            sample._active_subset = sample_active_subset(self.P, self.subset_size,
                                                         self.anchor_subset, self.anchor_subset_set,
                                                         self.anchor_complement, sample._idx)

        sample = self._compute_probs(sample)
        sample.weight = sample._i_prob.mean().reciprocal()

        if self.subset_size is not None and self.t <= self.T_burnin:
            # Weighted running average of the PIPs, used to pick the anchor set.
            self.pi = sample.weight * sample.pip + self.total_weight * self.pi
            self.total_weight += sample.weight
            self.pi /= self.total_weight
            # Refresh the anchor set every 100 iterations and once more at the end of burn-in.
            if (self.t > 99 and self.t % 100 == 0) or self.t == self.T_burnin:
                self._update_anchor(self.pi.argsort()[-self.anchor_size:])

        return sample

    def sample_alpha_beta(self, sample):
        """
        Resample the inclusion probability h from a Beta distribution whose parameters are the
        prior `(h_alpha, h_beta)` incremented by the numbers of active and inactive covariates,
        and store the resulting log odds in `sample._log_h_ratio`.

        :param sample: Current MCMC state; modified in place.
        :returns: The same `sample` object.
        """
        num_active = sample._active.size(-1)
        num_inactive = self.P - num_active
        sample.h_alpha = torch.tensor(self.h_alpha + num_active, device=self.device)
        sample.h_beta = torch.tensor(self.h_beta + num_inactive, device=self.device)
        h = Beta(sample.h_alpha, sample.h_beta).sample().item()
        sample._log_h_ratio = math.log(h) - math.log(1.0 - h)
        return sample