PosteriorSampler¶
- class axtreme.sampling.base.PosteriorSampler(*args, **kwargs)¶
Bases:
Protocol
Defines the protocol for a sampler function.
This follows the definition of the MCSampler.forward() method, but allows for simple function-based implementations.
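Example
A minimal sketch of the kind of function-based implementation this allows, assuming the sampler is called with a botorch Posterior and returns a tensor of samples; the function name and sample count below are illustrative, not part of axtreme:

    import torch
    from botorch.posteriors import Posterior

    def simple_sampler(posterior: Posterior) -> torch.Tensor:
        # Draw 16 Monte Carlo samples from the posterior.
        return posterior.rsample(torch.Size([16]))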
Note
Functions that use posteriors should check whether special mean and var methods are present if they support posterior samplers that require them (e.g. UT). See the note below for details.
Notes on samplers that require special aggregation of results:
We use a surrogate model to estimate a Quantity of Interest (QOI) and to provide the uncertainty of that estimate. It is typically challenging to calculate this on the posterior directly, so we often take samples of the posterior and calculate the value of interest using those. We can then combine those estimates to give the QOI mean/variance for the overall posterior.
Some methods (e.g. UT) select posterior samples in a special way, which means fewer samples are needed. The calculations for each of these samples then need to be combined in a special way to estimate the QOI mean/variance for the overall posterior. If such special methods are needed, they should be implemented on the posterior sampler with the following signature (a sketch is shown after this block):
def mean(self, x: torch.Tensor) -> torch.Tensor: ...
def var(self, x: torch.Tensor) -> torch.Tensor: ...
Where
x is the scores (e.g. QOIs) for each posterior sample.
It is expected to be of shape (batch_shape).
Todo
Batching gets confusing here: how should we do UT over multiple batch dims? Is it only ever flat, and should we say that here?
Output is a scalar.
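Example
A hypothetical sketch of such special aggregation. The class name and the use of per-sample weights are illustrative (loosely UT-style) and are not taken from axtreme:

    import torch

    class WeightedPosteriorSampler:
        # Sampler whose posterior samples carry weights (e.g. UT sigma points).

        def __init__(self, weights: torch.Tensor) -> None:
            # One weight per posterior sample; expected to sum to 1.
            self.weights = weights

        def mean(self, x: torch.Tensor) -> torch.Tensor:
            # x holds the score (e.g. QOI) computed from each posterior sample.
            return (self.weights * x).sum()

        def var(self, x: torch.Tensor) -> torch.Tensor:
            # Weighted variance of the per-sample scores around the weighted mean.
            mu = self.mean(x)
            return (self.weights * (x - mu) ** 2).sum()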
Todo
Need to fix/revisit this. It feels wrong that functions that use the posterior need to check for some optional methods that “might” be there (and those functions are not programmatically defined as part of the protocol). Extra info in #132.
Note
These functions are not part of the protocol because we want MCSampler from botorch to fall within the protocol definition.
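Example
An illustrative sketch (not part of the protocol) of how a consuming function might check for the optional mean/var methods and fall back to plain Monte Carlo aggregation when they are absent; aggregate_scores is a hypothetical helper, not an axtreme function:

    from typing import Any

    import torch

    def aggregate_scores(sampler: Any, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Combine per-posterior-sample scores into a QOI mean/variance estimate.
        if hasattr(sampler, "mean") and hasattr(sampler, "var"):
            # Samplers such as UT require their own aggregation of results.
            return sampler.mean(scores), sampler.var(scores)
        # Default: treat the scores as plain Monte Carlo estimates.
        return scores.mean(), scores.var()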
- __init__(*args, **kwargs)¶
Methods
__init__(*args, **kwargs)