MultiBatchSampler¶
- class axtreme.data.multi_dim_batch_sampler.MultiBatchSampler(sampler: Sampler[int] | Iterable[int], batch_shape: Size, partial_batch_dim: None | int = -1)¶
Bases: Sampler[list[int]]
Reads the entire sampler into batches of shape batch_shape.
The final batch may not have enough samples to completely fill the batch shape. Behaviour is then as follows:
- If partial_batch_dim is not None, attempt to batch samples allowing this dim to be variable in size, e.g.:
>>> batch_shape = torch.Size([3, 5])
>>> partial_batch_dim = -1
>>> # the final partial batch is reshaped with:
>>> remaining_samples.view([3, -1])
This check is performed up front, and the sampler will throw an error if the remaining samples cannot be batched this way.
Warning
If the batch shape changes, even in the partial_batch_index dimension, data will be returned in a different order.
See BatchSampler2d for details.
Todo
Determine the right approach for handling a partial batch that requires more than one partial index, e.g. fitting 5 items into batch_shape = torch.Size([3, 2]).
Perhaps return a new batch that is no bigger than the original in any dimension. This batcher is likely used for GP throughput, so the batches carry only performance meaning, not logical meaning.
- __init__(sampler: Sampler[int] | Iterable[int], batch_shape: Size, partial_batch_dim: None | int = -1) None ¶
Allows you to produce batches of arbitrary shape.
- Parameters:
sampler (Sampler[int] | Iterable[int]) –
Sampler (e.g. RandomSampler or SequentialSampler).
See the torch DataLoader implementation for examples.
batch_shape (torch.Size) – The batch shape created from the underlying sampler.
partial_batch_dim –
Dimension of the batch that can be partially filled if there are not enough samples in sampler.
Currently only one dimension can be partially filled.
If None, no dimension is allowed to be partially filled.
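To illustrate the iteration behaviour, here is a simplified pure-Python stand-in for the 2D case. This is a sketch under the assumptions above, not the actual implementation (which operates on a torch Sampler and torch.Size); the hypothetical multi_batch generator fills batches row-major and lets the last dimension shrink for the final partial batch.

```python
from itertools import islice
from math import prod


def multi_batch(indices, batch_shape):
    """Simplified 2D stand-in for MultiBatchSampler's iteration (illustrative only)."""
    it = iter(indices)
    batch_size = prod(batch_shape)
    while True:
        flat = list(islice(it, batch_size))
        if not flat:
            return
        if len(flat) % batch_shape[0]:
            # The real sampler performs this check up front and raises.
            raise ValueError("leftover samples do not fill the fixed dimensions")
        # Reshape row-major, letting the last dim shrink for the final batch
        # (the equivalent of remaining_samples.view([batch_shape[0], -1])).
        cols = len(flat) // batch_shape[0]
        yield [flat[i * cols:(i + 1) * cols] for i in range(batch_shape[0])]
```

With 21 indices and batch_shape (3, 5), this yields one full (3, 5) batch followed by a partial (3, 2) batch.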
Methods
__init__
(sampler, batch_shape[, ...]) Allows you to produce batches of arbitrary shape.