Source code for millipede.binomial

import math
from types import SimpleNamespace

import numpy as np
import torch
from polyagamma import random_polyagamma
from torch import cholesky_solve as chosolve
from torch import dot, einsum, matmul, sigmoid
from torch.distributions import Beta, Categorical, Uniform
from torch.linalg import norm
from torch.linalg import solve_triangular as trisolve
from torch.nn.functional import softplus

from .sampler import MCMCSampler
from .util import (
    arange_complement,
    get_loo_inverses,
    leave_one_out,
    leave_one_out_off_diagonal,
    safe_cholesky,
    sample_active_subset,
)
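
# Note on the augmentation used throughout this module: the kappa/omega bookkeeping follows the
# standard Polya-Gamma augmentation of logistic-type likelihoods (Polson, Scott & Windle),
#
#     (e^psi)^a / (1 + e^psi)^b = 2^{-b} exp(kappa * psi) * E_omega[exp(-omega * psi^2 / 2)],
#     kappa = a - b / 2,    omega ~ PG(b, 0),
#
# so a Binomial(T, sigmoid(psi)) observation gives kappa = Y - T / 2 (see `_kappa = self.Y - 0.5 * self.TC`
# below) and a Negative Binomial observation gives kappa = (Y - nu) / 2 (see `kappa_prop = 0.5 * (self.Y - nu_prop)`).
# Conditioned on omega the model is Gaussian in the coefficients, which is why the updates below reduce to
# Cholesky factorizations of X^T diag(omega) X with tau (and tau_intercept) added to the diagonal.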


class CountLikelihoodSampler(MCMCSampler):
    r"""
    MCMC algorithm for Bayesian variable selection for a generalized linear model with a Binomial or
    Negative Binomial likelihood (note that a Bernoulli likelihood is a special case of a Binomial likelihood).
    This class supports count-valued responses.

    To define a Binomial model specify `TC` but not `psi0`. To define a Negative Binomial model specify
    `psi0` but not `TC`.

    The details of the available models in :class:`CountLikelihoodSampler` are as follows.
    For both likelihoods the covariates and responses are defined as:

    .. math::

        X \in \mathbb{R}^{N \times P} \qquad \qquad Y \in \mathbb{Z}_{\ge 0}^{N}

    The inclusion of each covariate is governed by a Bernoulli random variable :math:`\gamma_p`. In particular
    :math:`\gamma_p = 0` corresponds to exclusion and :math:`\gamma_p = 1` corresponds to inclusion. The prior
    probability of inclusion is governed by :math:`h` or alternatively :math:`S`:

    .. math::

        h \in [0, 1] \qquad \rm{with} \qquad S \equiv hP

    Alternatively, if :math:`h` is not known a priori we can put a prior on :math:`h`:

    .. math::

        h \sim {\rm Beta}(\alpha, \beta) \qquad \rm{with} \qquad \alpha > 0 \;\;\;\; \beta > 0

    Putting this together, the model specification for the Binomial case is as follows:

    .. math::

        &\gamma_p \sim \rm{Bernoulli}(h) \qquad \rm{for} \qquad p=1,2,...,P

        &\beta_0 \sim \rm{Normal}(0, \tau_\rm{intercept}^{-1})

        &\beta_\gamma \sim \rm{Normal}(0, \tau^{-1} \mathbb{1}_\gamma)

        &Y_n \sim \rm{Binomial}(T_n, \sigma(\beta_0 + X_{n, \gamma} \cdot \beta_\gamma))

    where :math:`\sigma(\cdot)` is the logistic or sigmoid function and :math:`T_n` denotes the :math:`n`-th
    entry of the :math:`N`-dimensional vector of total counts. That is, each Binomial likelihood is equivalent
    to :math:`T_n` corresponding Bernoulli likelihoods.

    The Negative Binomial case is similar but includes a latent variable :math:`\nu > 0` that governs the
    dispersion of the Negative Binomial distribution:

    .. math::

        &\log \nu \sim \rm{ImproperPrior}(-\infty, \infty)

        &Y_n \sim \rm{NegBinomial}(\rm{mean}=\rm{exp}(\beta_0 + X_{n, \gamma} \cdot \beta_\gamma + \psi_{0, n}), \nu)

    The vector :math:`\psi_0 \in \mathbb{R}^N` allows the user to supply a datapoint-specific offset.
    We note that we use a parameterization of the Negative Binomial distribution where the variance is given by

    .. math::

        \rm{variance} = \rm{mean} + \rm{mean}^2 / \nu

    so that small values of :math:`\nu` correspond to large dispersion/variance and :math:`\nu \to \infty`
    recovers the Poisson distribution.

    Note that above the dimension of :math:`\beta_\gamma` depends on the number of covariates included in a
    particular model (i.e. on the number of non-zero entries in :math:`\gamma`).

    Usage of this class is only recommended for advanced users. For most users it should suffice to use one of
    :class:`BinomialLikelihoodVariableSelector`, :class:`BernoulliLikelihoodVariableSelector`, and
    :class:`NegativeBinomialLikelihoodVariableSelector`.

    :param tensor X: A N x P `torch.Tensor` of covariates. This is a required argument.
    :param tensor Y: A N-dimensional `torch.Tensor` of non-negative count-valued responses. This is a
        required argument.
    :param tensor X_assumed: A N x P' `torch.Tensor` of covariates that are always assumed to be part of the
        model. Defaults to `None`.
    :param tensor TC: A N-dimensional `torch.Tensor` of non-negative counts. This is a required argument if
        you wish to specify a Binomial model. Defaults to None.
    :param tensor psi0: A N-dimensional `torch.Tensor` of offsets `psi0`. This is a required argument if you
        wish to specify a Negative Binomial model. If the user specifies a float, `psi0` will be expanded to a
        N-dimensional vector internally. Defaults to None.
    :param S: Controls the expected number of covariates to include in the model a priori. Defaults to 5.0.
        To specify covariate-level prior inclusion probabilities provide a P-dimensional `torch.Tensor` of
        the form `(h_1, ..., h_P)`. If a tuple of positive floats `(alpha, beta)` is provided, the a priori
        inclusion probability is a latent variable governed by the corresponding Beta prior so that the
        sparsity level is inferred from the data. Note that for a given choice of `alpha` and `beta` the
        expected number of covariates to include in the model a priori is given by
        :math:`\frac{\alpha}{\alpha + \beta} \times P`. Also note that the mean number of covariates in the
        posterior can vary significantly from prior expectations, since the posterior is in effect a
        compromise between the prior and the observed data.
    :param float tau: Controls the precision of the coefficients in the isotropic prior. Defaults to 0.01.
    :param float tau_intercept: Controls the precision of the intercept in the isotropic prior. Defaults to 1.0e-4.
    :param float explore: This hyperparameter controls how greedy the MCMC algorithm is. Defaults to 5.0.
    :param float log_nu_rw_scale: This hyperparameter controls the proposal distribution for `nu` updates.
        Defaults to 0.05. Only applicable to the Negative Binomial case.
    :param bool omega_mh: Whether to include Metropolis-Hastings corrections during Polya-Gamma updates.
        Defaults to True. Only applicable to the Binomial case.
    :param float xi_target: This hyperparameter controls how frequently the MCMC algorithm makes Polya-Gamma
        updates. It also controls how often :math:`h` updates are made if :math:`h` is a latent variable.
        Defaults to 0.25.
    :param float init_nu: This hyperparameter controls the initial value of the dispersion parameter `nu`.
        Defaults to 5.0. Only applicable to the Negative Binomial case.
    :param bool verbose_constructor: Whether the class constructor should print some information to stdout
        upon initialization. Defaults to True.
    :param int subset_size: If `subset_size` is not None `subset_size` controls the amount of computational
        resources to use in Subset PG-wTGS. Otherwise if `subset_size` is None vanilla PG-wTGS is used.
        This argument is intended to be used for datasets with a very large number of covariates (e.g. tens
        of thousands or more). A typical value might be ~5-10% of the total number of covariates; smaller
        values result in more MCMC iterations per second but may lead to high variance PIP estimates.
        Defaults to None.
    :param int anchor_size: If `subset_size` is not None `anchor_size` controls how greedy Subset PG-wTGS is.
        If `anchor_size` is None it defaults to half of `subset_size`. For expert users only. Defaults to None.
    """
    def __init__(self, X, Y, X_assumed=None, TC=None, psi0=None, S=5.0,
                 tau=0.01, tau_intercept=1.0e-4, explore=5.0,
                 log_nu_rw_scale=0.05, omega_mh=True, xi_target=0.25, init_nu=5.0,
                 verbose_constructor=True, subset_size=None, anchor_size=None):
        super().__init__()
        if not ((TC is None and psi0 is not None) or (TC is not None and psi0 is None)):
            raise ValueError('CountLikelihoodSampler supports two modes of operation. ' +
                             'In order to specify a binomial likelihood the user must provide TC but ~not~ ' +
                             'provide psi0. For a negative binomial likelihood the user must provide psi0 ' +
                             'but ~not~ provide TC.')
        self.dtype = X.dtype
        self.device = X.device
        self.Xb = X
        self.N, self.P = X.shape

        if X_assumed is not None:
            assert self.dtype == X_assumed.dtype
            assert self.device == X_assumed.device
            if X.size(0) != X_assumed.size(0):
                raise ValueError("X and X_assumed must have the same number of rows.")
            assert X_assumed.size(-1) > 0
            self.Xb = torch.cat([self.Xb, X_assumed], dim=-1)

        if subset_size is not None and (subset_size <= 1 or subset_size >= self.P):
            raise ValueError("If subset_size is not None it must be strictly between 1 and P, "
                             "the number of covariates.")
        self.subset_size = subset_size

        if anchor_size is not None:
            if subset_size is None:
                raise ValueError("The anchor_size argument should only be used if subset_size is not None.")
            if anchor_size < 1 or anchor_size >= subset_size:
                raise ValueError("anchor_size should be strictly between 0 and subset_size.")

        self.Xb = torch.cat([self.Xb, X.new_ones(X.size(0), 1)], dim=-1)

        self.Y = Y
        self.Y_float = self.Y.type_as(X)
        self.tau = tau

        self.negbin = psi0 is not None

        if self.negbin:
            psi0 = X.new_tensor(psi0) if isinstance(psi0, float) else psi0
            if not (psi0.shape == Y.shape or psi0.shape == ()):
                raise ValueError("psi0 should either be a scalar or a one-dimensional array with " +
                                 "the same number of elements as Y.")
            if init_nu <= 0.0:
                raise ValueError("init_nu must be positive.")
            self.init_nu = init_nu
            self.psi0 = psi0
            self.log_nu_rw_scale = log_nu_rw_scale
        else:
            if not Y.shape == TC.shape or Y.ndim != 1:
                raise ValueError("Y and TC should both be one-dimensional arrays.")
            self.TC = TC
            self.TC_np = TC.data.cpu().numpy().copy()
            self.TC_float = self.TC.type_as(X)

        if self.N != Y.size(-1):
            raise ValueError("X and Y should be of shape (N, P) and (N,), respectively.")

        S = S if not isinstance(S, int) else float(S)

        if isinstance(S, float):
            if S >= self.P or S <= 0:
                raise ValueError("S must satisfy 0 < S < P or must be a tuple or tensor.")
        elif isinstance(S, tuple):
            if len(S) != 2 or not isinstance(S[0], float) or not isinstance(S[1], float) \
                    or S[0] <= 0.0 or S[1] <= 0.0:
                raise ValueError("If S is a tuple it must be a tuple of two positive floats (alpha, beta).")
        elif isinstance(S, torch.Tensor):
            if S.shape != (self.P,) or (S >= 1.0).any().item() or (S <= 0.0).any().item():
                raise ValueError("If S is a tensor it must be P-dimensional and all elements must be strictly" +
                                 " contained in (0, 1).")
        else:
            raise ValueError("S must be a float, tuple or tensor.")

        if tau <= 0.0:
            raise ValueError("tau must be positive.")
        if tau_intercept <= 0.0:
            raise ValueError("tau_intercept must be positive.")
        if explore < 0.0:
            raise ValueError("explore must be non-negative.")
        if log_nu_rw_scale < 0.0 and self.negbin:
            raise ValueError("log_nu_rw_scale must be non-negative.")
        if xi_target <= 0.0 or xi_target >= 1.0:
            raise ValueError("xi_target must be in the interval (0, 1).")

        if X_assumed is not None:
            self.assumed_covariates = torch.arange(self.P, self.P + X_assumed.size(-1) + 1, device=self.device)
        else:
            self.assumed_covariates = torch.tensor([self.P], device=self.device, dtype=torch.int64)
        self.Pa = self.assumed_covariates.size(-1)

        if isinstance(S, float):
            self.h = S / self.P
            self.log_h_ratio = math.log(self.h) - math.log(1.0 - self.h)
        elif isinstance(S, tuple):
            self.h_alpha, self.h_beta = S
            self.h = self.h_alpha / (self.h_alpha + self.h_beta)
            self.log_h_ratio = math.log(self.h) - math.log(1.0 - self.h)
        else:
            self.h = S
            self.log_h_ratio = S.log() - torch.log1p(-S)

        self.explore = explore / self.P
        self.half_log_tau = 0.5 * math.log(tau)
        self.tau_intercept = tau_intercept

        self.epsilon = 1.0e3 * torch.finfo(X.dtype).tiny
        self.xi = torch.tensor([5.0], device=X.device)
        self.xi_target = xi_target
        self.omega_mh = omega_mh

        self.uniform_dist = Uniform(0.0, X.new_ones(1)[0])

        if self.subset_size is not None:
            self.anchor_size = subset_size // 2 if anchor_size is None else anchor_size
            self.pi = X.new_ones(self.P) * self.h if isinstance(S, (float, tuple)) else self.h
            self.total_weight = 0.0
            self.comb_factor = (self.subset_size - self.anchor_size) / (self.P - self.anchor_size)
        else:
            self.comb_factor = 1.0

        if verbose_constructor:
            s1 = "Initialized CountLikelihoodSampler with {} likelihood and (N, P, S, epsilon, subset_size) = "
            s2 = "({}, {}, {:.1f}, {:.1f}, {})" if not isinstance(S, tuple) \
                else "({}, {}, ({:.1f}, {:.1f}), {:.1f}, {})"
            if isinstance(S, float):
                S = (S,)
            elif isinstance(S, torch.Tensor):
                S = (S.min().item(), S.max().item())
            print((s1 + s2).format("Negative Binomial" if self.negbin else "Binomial",
                                   self.N, self.P, *S, explore, subset_size))

    def initialize_sample(self, seed=None):
        self.accepted_omega_updates = 0
        self.attempted_omega_updates = 0
        self.acceptance_probs = []

        self.rng = np.random.default_rng(seed)
        if seed is not None:
            torch.manual_seed(seed)

        if not self.negbin:
            log_nu = None
            _omega = torch.from_numpy(random_polyagamma(self.TC_np, random_state=self.rng)).type_as(self.Xb)
        else:
            log_nu = torch.tensor(math.log(self.init_nu))
            _omega = torch.from_numpy(random_polyagamma(self.Y.data.cpu().numpy() + self.init_nu,
                                                        random_state=self.rng)).type_as(self.Xb)

        _psi0 = self.psi0 - log_nu if self.negbin else 0.0
        _kappa = 0.5 * (self.Y - log_nu.exp()) if self.negbin else self.Y - 0.5 * self.TC
        _kappa_omega = _kappa - _omega * _psi0
        _Z = einsum("np,n->p", self.Xb, _kappa_omega)

        sample = SimpleNamespace(gamma=self.Xb.new_zeros(self.P).bool(),
                                 _omega=_omega,
                                 beta=self.Xb.new_zeros(self.P + self.Pa),
                                 beta_mean=self.Xb.new_zeros(self.P + self.Pa),
                                 _psi0=_psi0,
                                 _idx=torch.randint(self.P, (), device=self.device),
                                 weight=0,
                                 log_nu=log_nu,
                                 _kappa=_kappa,
                                 _kappa_omega=_kappa_omega,
                                 _Z=_Z,
                                 _log_h_ratio=self.log_h_ratio,
                                 _active=torch.tensor([], dtype=torch.int64),
                                 _activeb=self.assumed_covariates)

        if self.subset_size is not None:
            Z_cent = einsum("np,n->p", self.Xb[:, :self.P], self.Y - self.Y.mean())
            self._update_anchor(Z_cent.abs().argsort()[-self.anchor_size:])
            sample._active_subset = sample_active_subset(self.P, self.subset_size, self.anchor_subset,
                                                         self.anchor_subset_set, self.anchor_complement,
                                                         sample._idx)

        if hasattr(self, "h_alpha"):
            sample.h_alpha = torch.tensor(self.h_alpha, device=self.device)
            sample.h_beta = torch.tensor(self.h_beta, device=self.device)

        sample = self.sample_beta(sample)  # populate self._L_active
        sample = self._compute_probs(sample)
        return sample

    def _update_anchor(self, anchor):
        self.anchor_subset = anchor
        self.anchor_subset_set = set(anchor.data.cpu().numpy().tolist())
        self.anchor_complement = arange_complement(self.P, anchor)

    def _compute_add_prob(self, sample):
        active, activeb = sample._active, sample._activeb
        if self.subset_size is not None:
            inactive = torch.zeros(self.P, device=self.device, dtype=torch.bool)
            inactive[sample._active_subset] = ~(sample.gamma[sample._active_subset])
            inactive = torch.nonzero(inactive).squeeze(-1)
        else:
            inactive = torch.nonzero(~sample.gamma).squeeze(-1)
        num_active = active.size(-1)
        assert num_active < self.P, "The MCMC sampler has been driven into a regime where " + \
            "all covariates have been selected. Are you sure you have chosen a reasonable prior? " + \
            "Are you sure there is signal in your data?"
        # TODO: do we need to compute all of this if subset_size is not None?
        X_omega = self.Xb * sample._omega.sqrt().unsqueeze(-1)
        X_omega_k = X_omega[:, inactive]
        Z_k = sample._Z[inactive]

        X_omega_active = X_omega[:, activeb]
        Z_active = sample._Z[activeb]

        Zt_active = trisolve(self._L_active, Z_active.unsqueeze(-1), upper=False).squeeze(-1)
        Xt_active = trisolve(self._L_active, X_omega_active.t(), upper=False).t()
        XtZt_active = einsum("np,p->n", Xt_active, Zt_active)

        XX_k = norm(X_omega_k, dim=0).pow(2.0)
        G_k_inv = XX_k + self.tau - norm(einsum("ni,nk->ik", Xt_active, X_omega_k), dim=0).pow(2.0)
        W_k_sq = (einsum("np,n->p", X_omega_k, XtZt_active) - Z_k).pow(2.0) / (G_k_inv + self.epsilon)
        log_det_ratio_inactive = -0.5 * G_k_inv.log() + self.half_log_tau

        if num_active > 1:
            active_loo = leave_one_out(active)  # I I-1
            active_loob = torch.cat([active_loo, self.assumed_covariates.expand(num_active, -1)], dim=-1)
            Z_active_loo = sample._Z[active_loob]
            F = torch.cholesky_inverse(self._L_active, upper=False)
            F_loo = get_loo_inverses(F)[:-self.Pa]
            Zt_active_loo = matmul(F_loo, Z_active_loo.unsqueeze(-1)).squeeze(-1)
            Zt_active_loo_sq = einsum("ij,ij->i", Zt_active_loo, Z_active_loo)
            X_I_X_k = leave_one_out_off_diagonal(self._precision).unsqueeze(-1)[:-self.Pa]
            F_X_I_X_k = matmul(F_loo, X_I_X_k).squeeze(-1)
            XXFXX = einsum("ij,ij->i", X_I_X_k.squeeze(-1), F_X_I_X_k)
            G_k_inv = norm(X_omega[:, active], dim=0).pow(2.0) + self.tau - XXFXX
            log_det_ratio_active = -0.5 * G_k_inv.log() + self.half_log_tau
        elif num_active == 1:
            XX_assumed = self._precision[-self.Pa:][:, -self.Pa:]
            L_assumed = safe_cholesky(XX_assumed)
            Zt_active_loo = trisolve(L_assumed, sample._Z[self.assumed_covariates].unsqueeze(-1),
                                     upper=False).squeeze(-1)
            Zt_active_loo_sq = norm(Zt_active_loo, dim=0).pow(2.0)
            log_det_ratio_active = L_assumed.diagonal().log().sum() - self._L_active.diagonal().log().sum()
            log_det_ratio_active += 0.5 * math.log(self.tau)
        elif num_active == 0:
            Zt_active_loo_sq = 0.0  # dummy values since no active covariates
            log_det_ratio_active = torch.tensor(0.0)

        log_h_ratio_active = sample._log_h_ratio[active] if isinstance(self.h, torch.Tensor) \
            else sample._log_h_ratio
        log_h_ratio_inactive = sample._log_h_ratio[inactive] if isinstance(self.h, torch.Tensor) \
            else sample._log_h_ratio

        log_odds_inactive = 0.5 * W_k_sq + log_det_ratio_inactive + log_h_ratio_inactive
        log_odds_active = 0.5 * (Zt_active.pow(2.0).sum() - Zt_active_loo_sq) + \
            log_det_ratio_active + log_h_ratio_active

        log_odds = self.Xb.new_full((self.P,), -torch.inf)
        log_odds[inactive] = log_odds_inactive
        log_odds[active] = log_odds_active

        return log_odds

    def _compute_probs(self, sample):
        sample._add_prob = sigmoid(self._compute_add_prob(sample))

        gamma = sample.gamma.type_as(sample._add_prob)
        prob_gamma_i = gamma * sample._add_prob + (1.0 - gamma) * (1.0 - sample._add_prob)
        _i_prob = 0.5 * (sample._add_prob + self.explore) / (prob_gamma_i + self.epsilon)

        if self.subset_size is not None:
            _i_prob[self.anchor_subset] *= self.comb_factor
            i_prob = torch.zeros_like(_i_prob)
            i_prob[sample._active_subset] = _i_prob[sample._active_subset]
            sample.pip = sample.gamma.type_as(i_prob)
            sample.pip[sample._active_subset] = sample._add_prob[sample._active_subset]
        else:
            i_prob = _i_prob
            sample.pip = sample._add_prob

        if self.t <= self.T_burnin:  # adapt xi
            xi_comb = self.xi * self.comb_factor
            self.xi += (self.xi_target - xi_comb / (xi_comb + i_prob.sum())) / math.sqrt(self.t + 1)
            self.xi.clamp_(min=0.01)

        sample._i_prob = torch.cat([self.xi * self.comb_factor, i_prob])
        return sample

    def mcmc_move(self, sample):
        self.t += 1
        sample._idx = Categorical(probs=sample._i_prob).sample() - 1

        if sample._idx.item() >= 0:
            sample.gamma[sample._idx] = ~sample.gamma[sample._idx]
            sample._active = torch.nonzero(sample.gamma).squeeze(-1)
            sample._activeb = torch.cat([sample._active, self.assumed_covariates], dim=-1)
            sample = self.sample_beta(sample)
        else:
            if hasattr(self, 'h_alpha'):
                if torch.rand(1).item() < 0.50:
                    sample = self.sample_alpha_beta(sample)
                    sample = self.sample_omega_nb(sample) if self.negbin else self.sample_omega_binomial(sample)
                else:
                    sample = self.sample_omega_nb(sample) if self.negbin else self.sample_omega_binomial(sample)
                    sample = self.sample_alpha_beta(sample)
            else:
                sample = self.sample_omega_nb(sample) if self.negbin else self.sample_omega_binomial(sample)

        if self.subset_size is not None:
            sample._active_subset = sample_active_subset(self.P, self.subset_size, self.anchor_subset,
                                                         self.anchor_subset_set, self.anchor_complement,
                                                         sample._idx)

        sample = self._compute_probs(sample)
        sample.weight = sample._i_prob.mean().reciprocal()

        if self.subset_size is not None and self.t <= self.T_burnin:
            self.pi = sample.weight * sample.pip + self.total_weight * self.pi
            self.total_weight += sample.weight
            self.pi /= self.total_weight
            if (self.t > 99 and self.t % 100 == 0) or self.t == self.T_burnin:
                self._update_anchor(self.pi.argsort()[-self.anchor_size:])

        return sample

    def sample_beta(self, sample):
        activeb = sample._activeb
        Xb_active = self.Xb[:, activeb]
        precision = Xb_active.t() @ (sample._omega.unsqueeze(-1) * Xb_active)
        precision.diagonal(dim1=-2, dim2=-1).add_(self.tau)
        precision[-1, -1].add_(self.tau_intercept - self.tau)
        self._L_active = safe_cholesky(precision)
        self._precision = precision

        sample.beta.zero_()
        sample.beta_mean.zero_()

        beta_active = chosolve(sample._Z[activeb].unsqueeze(-1), self._L_active).squeeze(-1)
        sample.beta_mean[activeb] = beta_active
        sample.beta[activeb] = beta_active + \
            trisolve(self._L_active, torch.randn(activeb.size(-1), 1, device=self.device, dtype=self.dtype),
                     upper=False).squeeze(-1)
        sample._psi = torch.mv(Xb_active, beta_active)
        return sample

    def sample_omega_binomial(self, sample, _save_intermediates=None):
        omega_prop = random_polyagamma(self.TC_np, sample._psi.data.cpu().numpy(), random_state=self.rng)
        omega_prop = torch.from_numpy(omega_prop).type_as(self.Xb)

        activeb = sample._activeb
        Xb_active = self.Xb[:, activeb]

        # some of these computations could be reused/saved but they are cheap
        # so we do them from scratch to avoid unnecessary complexity
        def compute_log_target(omega):
            precision = Xb_active.t() @ (omega.unsqueeze(-1) * Xb_active)
            precision.diagonal(dim1=-2, dim2=-1).add_(self.tau)
            precision[-1, -1].add_(self.tau_intercept - self.tau)
            L = safe_cholesky(precision)
            LZ = trisolve(L, sample._Z[activeb].unsqueeze(-1), upper=False).squeeze(-1)
            logdet = L.diag().log().sum() - L.size(-1) * self.half_log_tau
            return 0.5 * norm(LZ, dim=0).pow(2.0) - logdet, L, precision

        log_target_prop, L_prop, precision_prop = compute_log_target(omega_prop)

        beta_mean_prop = chosolve(sample._Z[activeb].unsqueeze(-1), L_prop).squeeze(-1)
        beta_prop = beta_mean_prop + \
            trisolve(L_prop, torch.randn(activeb.size(-1), 1, device=self.device, dtype=self.dtype),
                     upper=False).squeeze(-1)
        psi_prop = torch.mv(Xb_active, beta_mean_prop)

        if self.omega_mh:
            log_target_curr, _, _ = compute_log_target(sample._omega)

            delta_psi = psi_prop - sample._psi
            accept1 = log_target_prop - log_target_curr
            accept2 = dot(sample._kappa - self.Y_float, delta_psi)
            accept3 = 0.5 * (dot(omega_prop, sample._psi.pow(2.0)) - dot(sample._omega, psi_prop.pow(2.0)))
            accept4 = dot(self.TC_float, softplus(psi_prop) - softplus(sample._psi))
            accept = min(1.0, (accept1 + accept2 + accept3 + accept4).exp().item())

            if _save_intermediates is not None:
                _save_intermediates['omega'] = sample._omega.data.cpu().numpy().copy()
                _save_intermediates['omega_prop'] = omega_prop.data.cpu().numpy().copy()
                _save_intermediates['psi'] = sample._psi.data.cpu().numpy().copy()
                _save_intermediates['psi_prop'] = psi_prop.data.cpu().numpy().copy()
                _save_intermediates['TC_np'] = self.TC_np.copy()
                _save_intermediates['accept234'] = accept2 + accept3 + accept4

            if self.t >= self.T_burnin:
                self.acceptance_probs.append(accept)
            accept = self.uniform_dist.sample().item() < accept

            if self.t >= self.T_burnin:
                self.attempted_omega_updates += 1
                self.accepted_omega_updates += int(accept)
        elif self.t >= self.T_burnin:  # always accept mh move
            self.accepted_omega_updates += 1
            self.attempted_omega_updates += 1

        if not self.omega_mh or accept or (self.t < self.T_burnin // 2):
            sample._omega = omega_prop
            sample._psi = psi_prop
            sample.beta_mean.zero_()
            sample.beta_mean[activeb] = beta_mean_prop
            sample.beta.zero_()
            sample.beta[activeb] = beta_prop
            self._L_active = L_prop
            self._precision = precision_prop

        return sample

    def sample_omega_nb(self, sample, _save_intermediates=None):
        activeb = sample._activeb
        Xb_active = self.Xb[:, activeb]

        log_nu_prop = sample.log_nu + self.log_nu_rw_scale * torch.randn(1).item()
        nu_curr, nu_prop = sample.log_nu.exp(), log_nu_prop.exp()
        T_curr, T_prop = self.Y + nu_curr, self.Y + nu_prop
        psi0_prop = self.psi0 - log_nu_prop
        psi_mixed = sample._psi + psi0_prop

        omega_prop = random_polyagamma(T_prop.data.cpu().numpy(), psi_mixed.data.cpu().numpy(),
                                       random_state=self.rng)
        omega_prop = torch.from_numpy(omega_prop).type_as(self.Xb)

        kappa_prop = 0.5 * (self.Y - nu_prop)
        kappa_omega_prop = kappa_prop - omega_prop * psi0_prop
        Z_prop = einsum("np,n->p", self.Xb, kappa_omega_prop)

        def compute_log_target(omega, Z):
            precision = Xb_active.t() @ (omega.unsqueeze(-1) * Xb_active)
            precision.diagonal(dim1=-2, dim2=-1).add_(self.tau)
            precision[-1, -1].add_(self.tau_intercept - self.tau)
            L = safe_cholesky(precision)
            LZ = trisolve(L, Z[activeb].unsqueeze(-1), upper=False).squeeze(-1)
            logdet = L.diag().log().sum() - L.size(-1) * self.half_log_tau
            return 0.5 * norm(LZ, dim=0).pow(2.0) - logdet, L, precision

        log_target_prop, L_prop, precision_prop = compute_log_target(omega_prop, Z_prop)

        beta_mean_prop = chosolve(Z_prop[activeb].unsqueeze(-1), L_prop).squeeze(-1)
        beta_prop = beta_mean_prop + \
            trisolve(L_prop, torch.randn(activeb.size(-1), 1, device=self.device, dtype=self.dtype),
                     upper=False).squeeze(-1)
        psi_prop = torch.mv(Xb_active, beta_mean_prop)

        log_target_curr, _, _ = compute_log_target(sample._omega, sample._Z)

        accept1 = log_target_prop - log_target_curr \
            + (torch.lgamma(T_prop) - torch.lgamma(nu_prop)).sum() \
            - (torch.lgamma(T_curr) - torch.lgamma(nu_curr)).sum() \
            + (kappa_prop * psi0_prop).sum() - (sample._kappa * sample._psi0).sum() \
            + 0.5 * ((sample._omega * sample._psi0.pow(2.0)).sum() - (omega_prop * psi0_prop.pow(2.0)).sum())

        psi_mixed_prop = psi_prop + sample._psi0

        accept2 = dot(sample._kappa, psi_mixed_prop) - dot(kappa_prop, psi_mixed) \
            + 0.5 * (dot(omega_prop, psi_mixed.pow(2.0)) - dot(sample._omega, psi_mixed_prop.pow(2.0)))

        accept3 = dot(self.Y_float, psi_mixed - psi_mixed_prop) \
            - dot(T_prop, softplus(psi_mixed)) + dot(T_curr, softplus(psi_mixed_prop))

        accept = min(1.0, (accept1 + accept2 + accept3).exp().item())

        if _save_intermediates is not None:
            _save_intermediates['omega'] = sample._omega.data.cpu().numpy()
            _save_intermediates['omega_prop'] = omega_prop.data.cpu().numpy()
            _save_intermediates['psi_mixed'] = psi_mixed.data.cpu().numpy()
            _save_intermediates['psi_mixed_prop'] = psi_mixed_prop.data.cpu().numpy()
            _save_intermediates['T_curr'] = T_curr.data.cpu().numpy()
            _save_intermediates['T_prop'] = T_prop.data.cpu().numpy()
            _save_intermediates['delta_nu'] = nu_curr.item() - nu_prop.item()
            _save_intermediates['accept23'] = accept2 + accept3

        if self.t >= self.T_burnin:
            self.acceptance_probs.append(accept)
        accept = self.uniform_dist.sample().item() < accept
        if self.t >= self.T_burnin:
            self.attempted_omega_updates += 1
            self.accepted_omega_updates += int(accept)

        if accept or self.t < min(50, self.T_burnin // 4):
            sample.log_nu = log_nu_prop
            sample._omega = omega_prop
            sample._psi = psi_prop
            self._L_active = L_prop
            self._precision = precision_prop
            sample._kappa = kappa_prop
            sample._psi0 = psi0_prop
            sample._kappa_omega = kappa_omega_prop
            sample._Z = Z_prop
            sample.beta_mean.zero_()
            sample.beta_mean[activeb] = beta_mean_prop
            sample.beta.zero_()
            sample.beta[activeb] = beta_prop

        return sample

    def sample_alpha_beta(self, sample):
        num_active = sample._active.size(-1)
        num_inactive = self.P - num_active
        sample.h_alpha = torch.tensor(self.h_alpha + num_active, device=self.Xb.device)
        sample.h_beta = torch.tensor(self.h_beta + num_inactive, device=self.Xb.device)
        h = Beta(sample.h_alpha, sample.h_beta).sample().item()
        sample._log_h_ratio = math.log(h) - math.log(1.0 - h)
        return sample
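
In normal use the sampler is driven by the shared MCMC loop in the MCMCSampler base class (or, more simply, through one of the VariableSelector classes). The sketch below instead steps the sampler by hand on a small synthetic Binomial problem; it is illustrative only. In particular, setting the `t` and `T_burnin` bookkeeping attributes directly is an assumption about what the base-class driver would normally do, and the weighted average mirrors how `sample.weight` and `sample.pip` are defined above.

import torch
from millipede.binomial import CountLikelihoodSampler

N, P = 100, 10
X = torch.randn(N, P).double()
TC = torch.full((N,), 20, dtype=torch.int64)                  # total trial counts per datapoint
logits = X[:, 0] - 0.5 * X[:, 1]                              # only the first two covariates carry signal
Y = torch.binomial(TC.double(), torch.sigmoid(logits)).long()

sampler = CountLikelihoodSampler(X, Y, TC=TC, S=2.0, tau=0.01)
sampler.T_burnin, sampler.t = 500, 0                          # assumed: normally managed by the base-class driver
sample = sampler.initialize_sample(seed=0)

pip_numerator, weight_total = torch.zeros(P, dtype=X.dtype), 0.0
for _ in range(3000):
    sample = sampler.mcmc_move(sample)
    if sampler.t > sampler.T_burnin:                          # accumulate weighted PIPs after burn-in
        pip_numerator += sample.weight * sample.pip
        weight_total += sample.weight

print(pip_numerator / weight_total)                           # estimated posterior inclusion probabilities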