MultiBatchSampler

class axtreme.data.multi_dim_batch_sampler.MultiBatchSampler(sampler: Sampler[int] | Iterable[int], batch_shape: Size, partial_batch_dim: None | int = -1)

Bases: Sampler[list[int]]

Reads the entire sampler into batches of shape batch_shape.

The final batch may not have enough samples to completely fill the batch shape. Behaviour is then as follows:

  • If partial_batch_dim is not None, attempt to batch the remaining samples, allowing this dimension to be variable in size. E.g.:
    >>> batch_shape = torch.Size([3, 5])
    >>> partial_batch_dim = -1
    >>> remaining_samples.view([3, -1])  # the attempted reshape of the final batch
    

This check is performed up front: if the samples cannot be batched this way, the sampler will throw an error.

Warning

  • If the batch shape changes, even in the partial_batch_dim dimension, data will be returned in a different order.

  • See BatchSampler2d for details.
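The ordering caveat can be seen with a small sketch (illustrative only; `as_rows` is a hypothetical helper, not part of axtreme): a row-major reshape places the same flat samples into different rows depending on the batch shape.

```python
def as_rows(indices: list[int], shape: tuple[int, int]) -> list[list[int]]:
    """Row-major reshape of a flat index list, mirroring tensor.view(shape)."""
    rows, cols = shape
    return [indices[r * cols : (r + 1) * cols] for r in range(rows)]


# The same six samples land in different rows under different batch shapes,
# so consuming the batch row-by-row yields the data in a different order.
print(as_rows(list(range(6)), (2, 3)))  # [[0, 1, 2], [3, 4, 5]]
print(as_rows(list(range(6)), (3, 2)))  # [[0, 1], [2, 3], [4, 5]]
```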

Todo

  • Determine the right approach for handling partial batches that require more than one partial index, e.g. fitting 5 items into batch_shape = torch.Size([3, 2])

  • Perhaps return a new batch that is no bigger than the original in any dimension. This batcher is likely for GP throughput, so the batch shape then has only performance meaning, not logical meaning.

__init__(sampler: Sampler[int] | Iterable[int], batch_shape: Size, partial_batch_dim: None | int = -1) None

Allows you to produce batches of arbitrary shape.

Parameters:
  • sampler (Sampler[int] | Iterable[int]) –

    Sampler (e.g. RandomSampler or SequentialSampler).

    • See the torch DataLoader implementation for an example

  • batch_shape (torch.Size) – The batch shape created from the underlying sampler.

  • partial_batch_dim –

    Dimension of the batch that can be partially filled if there are not enough samples in the sampler.

    • Currently only one dimension can be partially filled

    • If None, no dimension is allowed to be partially filled

Methods

__init__(sampler, batch_shape[, ...])

Allows you to produce batches of arbitrary shape.