BatchInvariantSampler2d

class axtreme.data.batch_invariant_sampler.BatchInvariantSampler2d(sampler: Sampler[int] | Iterable[int], batch_shape: Size)

Bases: Sampler[list[list[int]]]

Returns 2d batches where the final dataset is invariant to changes in the last batch dimension.

The standard BatchSampler:
  • has 1 row that is batched in size b.

  • returns batches of shape 1 * b
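For comparison, the one-dimensional behaviour described above can be reproduced with torch's own BatchSampler. This is a minimal sketch; the index range of 6 and batch size of 2 are arbitrary choices for illustration.

```python
from torch.utils.data import BatchSampler, SequentialSampler

# SequentialSampler yields 0, 1, ..., len(data_source) - 1 in order.
sampler = SequentialSampler(range(6))

# Each batch is a flat list of at most batch_size indices (a single "row").
batches = list(BatchSampler(sampler, batch_size=2, drop_last=False))
print(batches)  # [[0, 1], [2, 3], [4, 5]]
```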

This BatchSampler has:
  • n rows that are batched in size b.

  • returns batches of shape n * b

Importantly, the concatenated batches (along the last dimension) will always be the same 2d matrix, regardless of the batch size.

Examples

>>> Conceptually the full dataset might have the following indices
[[ 1, 3, 5]
 [ 2, 4, 6]]
>>> If batched along index = -1 with batch size 2, it will return results as follows
b1= [[1,3],   b2= [[5],
     [2,4]]        [6]]
>>> concat([b1, b2], axis=-1)
[[ 1, 3, 5]
 [ 2, 4, 6]]

This is important when the final dataset produced should be invariant to changes in b.
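The invariance can be illustrated with a small pure-Python sketch. This is not the axtreme implementation: `batch_2d` and `concat_last_dim` are hypothetical helpers written only to reproduce the example above, assuming indices fill each column (the n rows) before moving along the batched dimension.

```python
from collections.abc import Iterable, Iterator


def batch_2d(indices: Iterable[int], n_rows: int, batch_size: int) -> Iterator[list[list[int]]]:
    """Yield 2d batches with n_rows rows and at most batch_size columns.

    Hypothetical helper: rows are filled first, the batched (last) dimension last.
    """
    flat = list(indices)
    if len(flat) % n_rows != 0:
        raise ValueError("number of indices must be divisible by n_rows")
    # Column j holds indices flat[j * n_rows : (j + 1) * n_rows].
    columns = [flat[i : i + n_rows] for i in range(0, len(flat), n_rows)]
    for start in range(0, len(columns), batch_size):
        chunk = columns[start : start + batch_size]
        # Transpose the chunk of columns into n_rows rows.
        yield [list(row) for row in zip(*chunk)]


def concat_last_dim(batches: list[list[list[int]]]) -> list[list[int]]:
    """Concatenate 2d batches along the last dimension."""
    n_rows = len(batches[0])
    return [[x for batch in batches for x in batch[r]] for r in range(n_rows)]


# Batch size 2 gives one full batch and one partial batch.
print(list(batch_2d(range(1, 7), n_rows=2, batch_size=2)))
# [[[1, 3], [2, 4]], [[5], [6]]]

# The concatenated result is the same for every batch size.
for b in (1, 2, 3):
    print(b, concat_last_dim(list(batch_2d(range(1, 7), n_rows=2, batch_size=b))))
```

Each iteration of the loop prints the same matrix, `[[1, 3, 5], [2, 4, 6]]`, which is the property the sampler is designed to guarantee.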

Note

  • The final batch can return a partial batch (n rows, less than b columns)

  • Related to issue #76

Todo

This is a very specific case and does not cater to the general case.
  • Can we make a more general 2d case? The batched dim needs to be filled last. Any aggregation is then expected on the batch dim

  • Can we make a general higher dim case?

  • Should we just take multiple samplers and batch them in parallel?
    • Pro: Might be cleaner conceptually

    • Con: When we treat as a dataset we want to get through all data before we start repeating data

__init__(sampler: Sampler[int] | Iterable[int], batch_shape: Size) → None

Will produce batch (along the rows) of a 2d batch.

Parameters:
  • sampler (Sampler[int] | Iterable[int]) – Sampler (e.g. RandomSampler or SequentialSampler) to be batched. Must be Sized (i.e. have __len__). See the torch DataLoader implementation for an example.

  • batch_shape (torch.Size) – The shape of each batch created from the underlying sampler.

Methods

__init__(sampler, batch_shape)

Will produce batch (along the rows) of a 2d batch.