axtreme.data.batch_invariant_sampler
Contains samplers where the total dataset produced is not affected by the size of the batch dimension used.
Batch invariance in 1d. This can be achieved using the standard BatchSampler, for example:
>>> from torch.utils.data import BatchSampler
>>> data = [1, 2, 3, 4, 5]
>>> list(BatchSampler(data, batch_size=3, drop_last=False))
[[1, 2, 3], [4, 5]]
>>> list(BatchSampler(data, batch_size=2, drop_last=False))
[[1, 2], [3, 4], [5]]
Regardless of batch size, these results can be concatenated along the batch dimension to produce the same result, [1, 2, 3, 4, 5].
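The 1d invariance property can be checked directly. The sketch below uses a hypothetical pure-Python helper, `batch_1d`, that assumes the same chunking as `BatchSampler(..., drop_last=False)` (consecutive chunks, last one possibly shorter); it is an illustration, not part of axtreme or PyTorch.

```python
from typing import Iterable, Iterator, List, Sequence


def batch_1d(data: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive chunks of ``data``; the final chunk may be shorter.

    Hypothetical helper assumed to mirror BatchSampler(data, batch_size, drop_last=False).
    """
    for start in range(0, len(data), batch_size):
        yield list(data[start : start + batch_size])


def concat_1d(batches: Iterable[List[int]]) -> List[int]:
    """Concatenate batches along the (only) batch dimension."""
    out: List[int] = []
    for batch in batches:
        out.extend(batch)
    return out


data = [1, 2, 3, 4, 5]
for size in range(1, len(data) + 2):
    # Every batch size reconstructs the original ordering.
    assert concat_1d(batch_1d(data, size)) == data
```

Because each batch is a contiguous slice taken in order, the only thing the batch size changes is where the cuts fall, never the order of the elements.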
Batch invariance in 2d:
>>> data = [1, 2, 3, 4, 5, 6]
We want to turn the data into a 2d dataset where the dimension being batched along (e.g. rows) does not affect the
final data produced. This is not the case with MultiBatchSampler:
>>> list(MultiBatchSampler(data, batch_shape=torch.Size([2, 3])))
[[[1, 2, 3],
  [4, 5, 6]]]
>>> b1, b2, b3 = list(MultiBatchSampler(data, batch_shape=torch.Size([2, 1])))
>>> print(f"{b1=}, {b2=}, {b3=}")
b1=[[1], [2]], b2=[[3], [4]], b3=[[5], [6]]
>>> concat([b1, b2, b3], axis=-1)
[[1, 3, 5],
 [2, 4, 6]]
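The layout in the MultiBatchSampler example above is consistent with each batch taking the next `rows * cols` items and filling them row by row. The sketch below reproduces that behaviour with hypothetical helpers (`row_major_batches`, `concat_last_dim`); it is inferred from the example output, not taken from MultiBatchSampler's source, and assumes `len(data)` is a multiple of `rows * cols`.

```python
from typing import List, Sequence

Matrix = List[List[int]]


def row_major_batches(data: Sequence[int], n_rows: int, n_cols: int) -> List[Matrix]:
    """Split data into consecutive chunks of n_rows * n_cols, each reshaped row-major.

    Hypothetical helper; assumes len(data) is a multiple of n_rows * n_cols.
    """
    size = n_rows * n_cols
    batches: List[Matrix] = []
    for start in range(0, len(data), size):
        chunk = data[start : start + size]
        # Fill the batch row by row (last dimension varies fastest).
        batches.append([list(chunk[r * n_cols : (r + 1) * n_cols]) for r in range(n_rows)])
    return batches


def concat_last_dim(batches: List[Matrix]) -> Matrix:
    """Concatenate 2d batches along the last (column) dimension."""
    n_rows = len(batches[0])
    return [[x for batch in batches for x in batch[r]] for r in range(n_rows)]


data = [1, 2, 3, 4, 5, 6]
print(concat_last_dim(row_major_batches(data, 2, 3)))  # [[1, 2, 3], [4, 5, 6]]
print(concat_last_dim(row_major_batches(data, 2, 1)))  # [[1, 3, 5], [2, 4, 6]]
```

Since every batch is filled independently, shrinking the last batch dimension changes which values share a row once the batches are stitched back together.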
The datasets produced put data in different locations depending on the batch shape. This is a problem if you will be aggregating (e.g. over rows) and expect the batches to be invariant over that dimension.
Provides invariant batching in a specific dimension.
Classes
Returns 2d batches where the final dataset is invariant to changes in the last batch dimension.
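One way to obtain invariance in the last batch dimension is to fix the layout of the full 2d dataset first and only then slice it into column batches. The sketch below illustrates that idea with hypothetical helpers (`invariant_batches`, `concat_last_dim`); it is not the implementation of the class documented here, and it assumes the full grid shape `(n_rows, n_cols)` is known up front and exactly filled by the data.

```python
from typing import Iterable, Iterator, List, Sequence

Matrix = List[List[int]]


def invariant_batches(
    data: Sequence[int], n_rows: int, n_cols: int, batch_cols: int
) -> Iterator[Matrix]:
    """Lay out the full (n_rows, n_cols) grid first, then slice along the last dimension.

    Hypothetical sketch: assumes len(data) == n_rows * n_cols.
    """
    if len(data) != n_rows * n_cols:
        raise ValueError("data must fill the (n_rows, n_cols) grid exactly")
    # The position of every value is decided here, independently of batch_cols.
    full = [list(data[r * n_cols : (r + 1) * n_cols]) for r in range(n_rows)]
    for start in range(0, n_cols, batch_cols):
        # Each batch keeps all rows and takes the next batch_cols columns.
        yield [row[start : start + batch_cols] for row in full]


def concat_last_dim(batches: Iterable[Matrix]) -> Matrix:
    """Concatenate 2d batches along the last (column) dimension."""
    batch_list = list(batches)
    n_rows = len(batch_list[0])
    return [[x for batch in batch_list for x in batch[r]] for r in range(n_rows)]


data = [1, 2, 3, 4, 5, 6]
for batch_cols in (1, 2, 3):
    # Same result for every column batch size, including a ragged last batch (2).
    assert concat_last_dim(invariant_batches(data, 2, 3, batch_cols)) == [[1, 2, 3], [4, 5, 6]]
```

The trade-off in this sketch is that the total number of columns must be chosen before batching; in exchange, the batch size only controls how the columns are grouped, so per-row aggregates come out the same for any `batch_cols`.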