axtreme.distributions.utils

Helpers for working with distributions.

Functions

dist_dtype(dist)

Return the dtype the distribution calculates values in.

index_batch_dist(dist, index)

Applies indexing to a Distribution object along the batch dimension.

index_batch_mixture_dist(dist, index)

Applies indexing to a MixtureSameFamily object along the batch dimension.

axtreme.distributions.utils.dist_dtype(dist: Distribution) → dtype

Return the dtype the distribution calculates values in.

Parameters may be of different dtypes. It appears the distribution defaults to the largest dtype.

Parameters:

dist – the distribution to find the dtype of.

Returns:

The dtype the distribution returns results in.
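For illustration, a minimal sketch (the expected result assumes the “largest dtype” behaviour described above; the mixed-dtype Normal is constructed purely for demonstration):

>>> import torch
>>> from torch.distributions import Normal
>>> from axtreme.distributions.utils import dist_dtype
>>> # loc is float32 and scale is float64, so samples are promoted to float64
>>> dist = Normal(torch.tensor(0.0), torch.tensor(1.0, dtype=torch.float64))
>>> dist_dtype(dist)
torch.float64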

axtreme.distributions.utils.index_batch_dist(dist: Distribution, index: tuple[slice | int, ...]) → Distribution

Applies indexing to a Distribution object along the batch dimension.

Parameters:
  • dist – The distribution to be indexed (only along the batch dimensions)

  • index – The index/slice to be applied, e.g. (slice(1, 4), Ellipsis) is equivalent to [1:4, ...]. Slices larger than the batch dimension will raise an IndexError.

Returns:

A “view” of the underlying distribution. This is a new object, but we call it a view because it is built on a view of the underlying data.

Note

The returned distribution is built on a view of the underlying data. This follows the behaviour of slicing tensors in PyTorch, as detailed in the PyTorch documentation. As such, gradients etc. remain connected.
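For illustration, a minimal sketch of indexing a plain Normal distribution (the batch_shape values shown assume standard torch.distributions behaviour):

>>> import torch
>>> from torch.distributions import Normal
>>> from axtreme.distributions.utils import index_batch_dist
>>> dist = Normal(loc=torch.zeros(5, 3), scale=torch.ones(5, 3))
>>> dist.batch_shape
torch.Size([5, 3])
>>> # (slice(1, 4), Ellipsis) is equivalent to [1:4, ...]
>>> sub = index_batch_dist(dist, (slice(1, 4), Ellipsis))
>>> sub.batch_shape
torch.Size([3, 3])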

axtreme.distributions.utils.index_batch_mixture_dist(dist: MixtureSameFamily, index: tuple[slice | int, ...]) → MixtureSameFamily

Applies indexing to a MixtureSameFamily object along the batch dimension.

Parameters:
  • dist – The distribution to be indexed (only along the batch dimensions)

  • index – The index/slice to be applied, e.g. (slice(1, 4), Ellipsis) is equivalent to [1:4, ...]. Slices larger than the batch dimension will raise an IndexError.

Returns:

A “view” of the underlying distribution. This is a new object, but we call it a view because it is built on a view of the underlying data.

Note

The returned distribution is built on a view of the underlying data. This follows the behaviour of slicing tensors in PyTorch, as detailed in the PyTorch documentation. As such, gradients etc. remain connected.
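For illustration, a minimal sketch with a two-component Gaussian mixture (the batch_shape values shown assume standard torch.distributions behaviour):

>>> import torch
>>> from torch.distributions import Categorical, MixtureSameFamily, Normal
>>> from axtreme.distributions.utils import index_batch_mixture_dist
>>> # batch_shape (5,), with 2 mixture components per batch entry
>>> mix = Categorical(torch.ones(5, 2))
>>> comp = Normal(torch.randn(5, 2), torch.ones(5, 2))
>>> dist = MixtureSameFamily(mix, comp)
>>> dist.batch_shape
torch.Size([5])
>>> sub = index_batch_mixture_dist(dist, (slice(1, 4),))
>>> sub.batch_shape
torch.Size([3])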