axtreme.distributions.utils
Helpers for working with distributions.
Functions
| dist_dtype(dist) | Return the dtype the distribution calculates values in. |
| index_batch_dist(dist, index) | Applies indexing to a distribution along the batch dimension. |
| index_batch_mixture_dist(dist, index) | Applies indexing to a MixtureSameFamily object along the batch dimension. |
- axtreme.distributions.utils.dist_dtype(dist: Distribution) → dtype
Return the dtype the distribution calculates values in.
Parameters may be of different dtypes; it appears the distribution defaults to the largest one.
- Parameters:
dist – the distribution to find the dtype of.
- Returns:
The dtype the distribution returns results in.
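One way to estimate this dtype, sketched here under our own assumptions rather than taken from the axtreme source, is to apply PyTorch's type promotion (`torch.promote_types`) across every tensor parameter the distribution declares in `arg_constraints`, recursing into the two parts of a `MixtureSameFamily` (which declares no parameters of its own). The helper name `dist_dtype_sketch` is hypothetical.

```python
import torch
from torch.distributions import Categorical, Distribution, MixtureSameFamily, Normal


def dist_dtype_sketch(dist: Distribution) -> torch.dtype:
    """Hypothetical helper: promote the dtypes of a distribution's tensor parameters.

    Assumes every relevant parameter is exposed as an attribute named in
    ``dist.arg_constraints``. A sketch, not the axtreme implementation.
    """
    if isinstance(dist, MixtureSameFamily):
        # MixtureSameFamily has empty arg_constraints, so combine its two parts.
        return torch.promote_types(
            dist_dtype_sketch(dist.mixture_distribution),
            dist_dtype_sketch(dist.component_distribution),
        )

    dtype = None
    for name in dist.arg_constraints:
        value = getattr(dist, name, None)
        if isinstance(value, torch.Tensor):
            # promote_types picks the "larger" dtype, e.g. float32 + float64 -> float64.
            dtype = value.dtype if dtype is None else torch.promote_types(dtype, value.dtype)

    # No tensor parameters found: fall back to PyTorch's default floating-point dtype.
    return torch.get_default_dtype() if dtype is None else dtype


# Usage: parameters of mixed precision.
mixed = Normal(torch.tensor(0.0, dtype=torch.float32), torch.tensor(1.0, dtype=torch.float64))
print(dist_dtype_sketch(mixed))                # torch.float64
print(mixed.log_prob(torch.tensor(0.5)).dtype)  # torch.float64

mix = MixtureSameFamily(
    Categorical(probs=torch.ones(2)),  # float32 mixture weights
    Normal(torch.zeros(2, dtype=torch.float64), torch.ones(2, dtype=torch.float64)),
)
print(dist_dtype_sketch(mix))  # torch.float64
```

The `log_prob` line shows the "largest dtype" behaviour directly: a `float32` location combined with a `float64` scale yields `float64` results.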
- axtreme.distributions.utils.index_batch_dist(dist: Distribution, index: tuple[slice | int, ...]) → Distribution
Applies indexing to a distribution along the batch dimension.
- Parameters:
dist – The distribution to be indexed (only along the batch dimensions).
index – The index/slice to be applied, e.g. (slice(1, 4), Ellipsis) is equivalent to [1:4, ...]. Slices larger than the batch dimension will cause an index error.
- Returns:
A “view” of the underlying distribution. This is a new object, but we call it a view because it is built on a view of the underlying data.
Note
The returned distribution is built on a view of the underlying data. This follows the behaviour of slicing tensors in PyTorch, as detailed in the PyTorch documentation on tensor views. As such, the two objects share memory and gradients flow back to the original parameters.
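For a distribution whose batch dimensions are carried by a few named tensor parameters, the indexing can be sketched by slicing each parameter and rebuilding the distribution from the slices. Because basic slicing returns a view, the rebuilt object shares storage and autograd history with the original. The helper `index_batch_dist_sketch` and its `param_names` argument are hypothetical; the default `("loc", "scale")` assumes a location-scale family such as `Normal` whose constructor accepts those keywords.

```python
from __future__ import annotations

import torch
from torch.distributions import Distribution, Normal


def index_batch_dist_sketch(
    dist: Distribution,
    index: tuple[slice | int, ...],
    param_names: tuple[str, ...] = ("loc", "scale"),
) -> Distribution:
    """Hypothetical helper: index the batch dimensions of a distribution.

    Assumes each name in ``param_names`` is both a constructor keyword of
    ``type(dist)`` and an attribute whose leading dimensions are the batch dims.
    """
    # Basic slicing returns views, so the new parameters share memory and autograd history.
    params = {name: getattr(dist, name)[index] for name in param_names}
    return type(dist)(**params)


# Usage: keep batch entries 1..3, equivalent to [1:4, ...].
loc = torch.arange(6.0, requires_grad=True)
dist = Normal(loc, torch.ones(6))
sub = index_batch_dist_sketch(dist, (slice(1, 4), Ellipsis))
print(sub.batch_shape)  # torch.Size([3])

# Gradients flow back to the original parameter through the view.
sub.mean.sum().backward()
print(loc.grad)  # tensor([0., 1., 1., 1., 0., 0.])
```

Rebuilding through the constructor keeps the sketch simple, at the cost of needing to know which attributes hold the batch-shaped parameters.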
- axtreme.distributions.utils.index_batch_mixture_dist(dist: MixtureSameFamily, index: tuple[slice | int, ...]) → MixtureSameFamily
Applies indexing to a MixtureSameFamily object along the batch dimension.
- Parameters:
dist – The distribution to be indexed (only along the batch dimensions).
index – The index/slice to be applied, e.g. (slice(1, 4), Ellipsis) is equivalent to [1:4, ...]. Slices larger than the batch dimension will cause an index error.
- Returns:
A “view” of the underlying distribution. This is a new object, but we call it a view because it is built on a view of the underlying data.
Note
The returned distribution is built on a view of the underlying data. This follows the behaviour of slicing tensors in PyTorch, as detailed in the PyTorch documentation on tensor views. As such, the two objects share memory and gradients flow back to the original parameters.
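A `MixtureSameFamily` carries its batch structure in two places: the `Categorical` mixture weights (logits of shape `batch_shape + (k,)`) and the component distribution (batch shape `batch_shape + (k,)`). A minimal sketch, assuming `Normal` components and an index that addresses only the leading batch dimensions (as in `(slice(1, 4), Ellipsis)`), slices both and rebuilds the mixture. The helper name `index_batch_mixture_dist_sketch` is hypothetical and is not the axtreme implementation.

```python
from __future__ import annotations

import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal


def index_batch_mixture_dist_sketch(
    dist: MixtureSameFamily, index: tuple[slice | int, ...]
) -> MixtureSameFamily:
    """Hypothetical helper: index the batch dimensions of a mixture of Normals.

    Assumes ``index`` only addresses leading batch dimensions, so the trailing
    component dimension (size k) is left untouched.
    """
    weights = dist.mixture_distribution  # Categorical, logits shape: batch_shape + (k,)
    comps = dist.component_distribution  # Normal, batch shape: batch_shape + (k,)
    # Slicing the stored parameter tensors returns views, keeping autograd connected.
    new_weights = Categorical(logits=weights.logits[index])
    new_comps = Normal(comps.loc[index], comps.scale[index])
    return MixtureSameFamily(new_weights, new_comps)


# Usage: 5 batch entries, each an equally weighted 2-component mixture.
mixture = MixtureSameFamily(
    Categorical(logits=torch.zeros(5, 2)),
    Normal(torch.arange(10.0).reshape(5, 2), torch.ones(5, 2)),
)
sub = index_batch_mixture_dist_sketch(mixture, (slice(1, 4), Ellipsis))
print(sub.batch_shape)  # torch.Size([3])
print(sub.mean)  # tensor([2.5000, 4.5000, 6.5000])
```

Indexing both parts with the same tuple keeps the mixture weights aligned with their components, which is why the index must not touch the trailing `k` dimension.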