Data Splitters


DataSplitter(*args, **kwargs)

Base class for all data splitters.


RandomSplitter(probs, seed=None)

Randomly splits items.


  • probs Sequence[int]: Sequence of probabilities that must sum to one. The length of the sequence is the number of groups to split the items into.
  • seed int: Internal seed used for shuffling the items. Define this if you need reproducible results.


Split data into three random groups.

idmap = IDMap(["file1", "file2", "file3", "file4"])

data_splitter = RandomSplitter([0.6, 0.2, 0.2], seed=42)
splits = data_splitter(idmap)

np.testing.assert_equal(splits, [[1, 3], [0], [2]])
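
The idea of splitting by probabilities can be sketched in plain NumPy. This is only an illustration of the technique, not the library's implementation: random_split is a hypothetical helper, and it assumes the items are addressed by the indices 0..n_items-1. Because the shuffling method may differ, it will not necessarily reproduce the exact groups shown in the example above.

```python
import numpy as np


def random_split(n_items, probs, seed=None):
    """Hypothetical helper: shuffle the indices and cut them according to probs."""
    if not np.isclose(sum(probs), 1.0):
        raise ValueError("probs must sum to one")

    # A seeded generator makes the shuffle reproducible.
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_items)

    # Cumulative probabilities give the cut points. The last one (1.0) is
    # dropped so that np.split returns exactly len(probs) groups.
    cut_points = np.round(np.cumsum(probs)[:-1] * n_items).astype(int)
    return [group.tolist() for group in np.split(shuffled, cut_points)]


groups = random_split(10, [0.6, 0.2, 0.2], seed=42)
print([len(group) for group in groups])
```

Every index lands in exactly one group, and calling the helper twice with the same seed returns the same groups.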



FixedSplitter(splits)

Split ids based on predefined splits.


  • splits: The predefined splits.


Split data into three pre-defined groups.

idmap = IDMap(["file1", "file2", "file3", "file4"])
presplits = [["file4", "file3"], ["file2"], ["file1"]]

data_splitter = FixedSplitter(presplits)
splits = data_splitter(idmap=idmap)

assert splits == [[3, 2], [1], [0]]
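
The mapping from ids to indices in the example above can be sketched without the library. fixed_split is a hypothetical helper, and the sketch assumes that each id's index is its position in the list the id map was built from, which is consistent with the result shown above.

```python
def fixed_split(ids, presplits):
    """Hypothetical helper: translate groups of ids into groups of positions."""
    # Assumption: the index of an id is its position in the original list.
    position = {item_id: index for index, item_id in enumerate(ids)}
    return [[position[item_id] for item_id in group] for group in presplits]


ids = ["file1", "file2", "file3", "file4"]
presplits = [["file4", "file3"], ["file2"], ["file1"]]

assert fixed_split(ids, presplits) == [[3, 2], [1], [0]]
```

The order of ids inside each predefined group is preserved, which is why the first group comes out as [3, 2] rather than [2, 3].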


SingleSplitSplitter(*args, **kwargs)

Return all items in a single group, without shuffling.
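
As a rough sketch of this behaviour, assuming the result has the same shape as the other splitters' output (a list of groups of indices), a single-group split could look like this. single_split is a hypothetical helper, not part of the library.

```python
def single_split(n_items):
    """Hypothetical helper: put every index into one group, in original order."""
    # No shuffling: indices stay in ascending order.
    return [list(range(n_items))]


assert single_split(4) == [[0, 1, 2, 3]]
```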