load_dataset API

HuggingFace-style dataset loading interface

The load_dataset() function provides a HuggingFace Datasets-style interface for loading typed datasets.

Overview

Key differences from HuggingFace Datasets:

  • Requires an explicit sample_type parameter (a typed dataclass) unless using an index; see the sketch below
  • Returns atdata.Dataset[ST] instead of HF Dataset
  • Built on WebDataset for efficient streaming
  • No Arrow caching layer
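
For a concrete contrast, the two calls look roughly like this. This is a sketch: the first call assumes the Hugging Face datasets package is installed, and TextSample is the class defined in Basic Usage below.

# HuggingFace Datasets: the schema is inferred from the data
from datasets import load_dataset as hf_load_dataset
hf_ds = hf_load_dataset("json", data_files="data.jsonl", split="train")

# atdata: the sample type is declared up front
from atdata import load_dataset
ds = load_dataset("path/to/data.tar", TextSample, split="train")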

Basic Usage

import atdata
from atdata import load_dataset

@atdata.packable
class TextSample:
    text: str
    label: int

# Load a specific split
train_ds = load_dataset("path/to/data.tar", TextSample, split="train")

# Load all splits (returns DatasetDict)
ds_dict = load_dataset("path/to/data/", TextSample)
train_ds = ds_dict["train"]
test_ds = ds_dict["test"]
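
Once loaded, samples stream lazily. A minimal iteration sketch, using the .ordered() batching method shown in the complete example at the end of this page, and assuming batch fields behave as lists as they do there:

# Iterate batches in shard order; each batch groups 32 samples
for batch in train_ds.ordered(batch_size=32):
    print(batch.text[:2])   # first two text fields in the batch
    print(batch.label[:2])  # corresponding labels
    break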

Path Formats

WebDataset Brace Notation

# Range notation
ds = load_dataset("data-{000000..000099}.tar", MySample, split="train")

# List notation
ds = load_dataset("data-{train,test,val}.tar", MySample, split="train")

Glob Patterns

# Match all tar files
ds = load_dataset("path/to/*.tar", MySample)

# Match pattern
ds = load_dataset("path/to/train-*.tar", MySample, split="train")

Local Directory

# Scans for .tar files
ds = load_dataset("./my-dataset/", MySample)

Remote URLs

# S3 (public buckets)
ds = load_dataset("s3://bucket/data-{000..099}.tar", MySample, split="train")

# HTTP/HTTPS
ds = load_dataset("https://example.com/data.tar", MySample, split="train")

# Google Cloud Storage
ds = load_dataset("gs://bucket/data.tar", MySample, split="train")
Note

For private S3 buckets or S3-compatible storage with authentication, use atdata.S3Source with Dataset directly. See Datasets for details.
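
A rough sketch of that pattern is below. The S3Source and Dataset arguments shown here are guesses for illustration only; the real constructor signatures (credentials, endpoint, and so on) are documented on the Datasets page.

import atdata

# Hypothetical usage -- see the Datasets page for the actual API
source = atdata.S3Source("s3://private-bucket/data-{000..099}.tar")
ds = atdata.Dataset(source, sample_type=MySample)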

Index Lookup

from atdata.local import LocalIndex

index = LocalIndex()

# Load from local index (auto-resolves type from schema)
ds = load_dataset("@local/my-dataset", index=index, split="train")

# With explicit type
ds = load_dataset("@local/my-dataset", MySample, index=index, split="train")

Split Detection

Splits are automatically detected from filenames and directories:

Pattern                                     Detected split
------------------------------------------  --------------
train-*.tar, training-*.tar                 train
test-*.tar, testing-*.tar                   test
val-*.tar, valid-*.tar, validation-*.tar    validation
dev-*.tar, development-*.tar                validation
train/*.tar (directory)                     train
test/*.tar (directory)                      test
Note

Files without a detected split default to “train”.
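
For example, given a directory containing train-000.tar, val-000.tar, and notes-000.tar, detection would behave like this (a sketch of the rules in the table above):

# Directory contents: train-000.tar, val-000.tar, notes-000.tar
ds_dict = load_dataset("./my-dataset/", MySample)

print(sorted(ds_dict.keys()))
# ['train', 'validation']
# notes-000.tar has no recognized prefix, so it falls into 'train'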

DatasetDict

When called without split=, load_dataset() returns a DatasetDict:

ds_dict = load_dataset("path/to/data/", MySample)

# Access splits
train_ds = ds_dict["train"]
test_ds = ds_dict["test"]

# Iterate splits
for name, dataset in ds_dict.items():
    print(f"{name}: {len(dataset.shard_list)} shards")

# Properties
print(ds_dict.num_shards)    # {'train': 10, 'test': 2}
print(ds_dict.sample_type)   # <class 'MySample'>
print(ds_dict.streaming)     # False
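
Not every dataset ships every split, so it can be worth guarding access. A small sketch, assuming DatasetDict.keys() behaves like a dict's keys as in the complete example below:

# Fall back to the train split when no validation split was detected
eval_split = "validation" if "validation" in ds_dict.keys() else "train"
eval_ds = ds_dict[eval_split]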

Explicit Data Files

Override automatic detection with data_files:

# Single pattern
ds = load_dataset(
    "path/to/",
    MySample,
    data_files="custom-*.tar",
)

# List of patterns
ds = load_dataset(
    "path/to/",
    MySample,
    data_files=["shard-000.tar", "shard-001.tar"],
)

# Explicit split mapping
ds = load_dataset(
    "path/to/",
    MySample,
    data_files={
        "train": "training-shards-*.tar",
        "test": "eval-data.tar",
    },
)
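
When data_files is a mapping, its keys become the split names. Presumably this composes with split= to select one of them directly; that combination is not shown above, so treat this as a hedged sketch:

# Sketch: select one split from an explicit mapping
train_ds = load_dataset(
    "path/to/",
    MySample,
    data_files={
        "train": "training-shards-*.tar",
        "test": "eval-data.tar",
    },
    split="train",
)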

Streaming Mode

The streaming parameter marks the returned dataset as streaming:

# Mark as streaming
ds_dict = load_dataset("path/to/data.tar", MySample, streaming=True)

# Check streaming status
if ds_dict.streaming:
    print("Streaming mode")
Tip

atdata datasets are always lazy/streaming via WebDataset pipelines. This parameter primarily signals intent.

Auto Type Resolution

When using index lookup, the sample type can be resolved automatically:

from atdata.local import LocalIndex

index = LocalIndex()

# No sample_type needed - resolved from schema
ds = load_dataset("@local/my-dataset", index=index, split="train")

# Type is inferred from the stored schema
sample_type = ds.sample_type
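
Since the resolved class is a plain typed dataclass, it can be inspected with standard Python attributes before iterating, which is useful when the schema lives only in the index:

# sample_type was reconstructed from the stored schema
print(sample_type.__name__)         # e.g. 'MySample'
print(sample_type.__annotations__)  # field names and their types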

Error Handling

try:
    ds = load_dataset("path/to/data.tar", MySample, split="train")
except FileNotFoundError:
    print("No data files found")
except ValueError as e:
    if "Split" in str(e):
        print("Requested split not found")
    else:
        print(f"Invalid configuration: {e}")
except KeyError:
    print("Dataset not found in index")

Complete Example

import numpy as np
from numpy.typing import NDArray
import atdata
from atdata import load_dataset
import webdataset as wds

# 1. Define sample type
@atdata.packable
class ImageSample:
    image: NDArray
    label: str

# 2. Create dataset files
for split in ["train", "test"]:
    with wds.TarWriter(f"{split}-000.tar") as sink:
        for i in range(100):
            sample = ImageSample(
                image=np.random.rand(64, 64, 3).astype(np.float32),
                label=f"sample_{i}",
            )
            sink.write({**sample.as_wds, "__key__": f"{i:06d}"})

# 3. Load with split detection
ds_dict = load_dataset("./", ImageSample)
print(ds_dict.keys())  # dict_keys(['train', 'test'])

# 4. Iterate
for batch in ds_dict["train"].ordered(batch_size=16):
    print(batch.image.shape)  # (16, 64, 64, 3)
    print(batch.label)        # ['sample_0', 'sample_1', ...]
    break

# 5. Load specific split
train_ds = load_dataset("./", ImageSample, split="train")
for batch in train_ds.ordered(batch_size=32):
    process(batch)  # process() stands in for your own training/processing step