Datasets

Loading and iterating typed WebDataset tar files

The Dataset class provides typed iteration over WebDataset tar files with automatic batching and lens transformations.

Creating a Dataset

import atdata
from numpy.typing import NDArray

@atdata.packable
class ImageSample:
    image: NDArray
    label: str

# Single shard (string URL - most common)
dataset = atdata.Dataset[ImageSample]("data-000000.tar")

# Multiple shards with brace notation
dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")

The type parameter [ImageSample] specifies what sample type the dataset contains. This enables type-safe iteration and automatic deserialization.

Data Sources

Datasets can be created from different data sources using the DataSource protocol:

URL Source (default)

When you pass a string to Dataset, it is automatically wrapped in a URLSource:

# These are equivalent:
dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")
dataset = atdata.Dataset[ImageSample](atdata.URLSource("data-{000000..000009}.tar"))

S3 Source

For private S3 buckets or S3-compatible storage (Cloudflare R2, MinIO), use S3Source:

# From explicit credentials
source = atdata.S3Source(
    bucket="my-bucket",
    keys=["data-000000.tar", "data-000001.tar"],
    endpoint="https://my-r2-account.r2.cloudflarestorage.com",
    access_key="AKID...",
    secret_key="SECRET...",
)
dataset = atdata.Dataset[ImageSample](source)

# From S3 URLs
source = atdata.S3Source.from_urls([
    "s3://my-bucket/data-000000.tar",
    "s3://my-bucket/data-000001.tar",
])
dataset = atdata.Dataset[ImageSample](source)

Note

S3Source uses boto3 for streaming, enabling authentication with private buckets. For public S3 URLs, a string URL with URLSource works directly.

Iteration Modes

Ordered Iteration

Iterate through samples in their original order:

# With batching (default batch_size=1)
for batch in dataset.ordered(batch_size=32):
    images = batch.image  # numpy array (32, H, W, C)
    labels = batch.label  # list of 32 strings

# Without batching (raw samples)
for sample in dataset.ordered(batch_size=None):
    print(sample.label)

Shuffled Iteration

Iterate with randomized order at both shard and sample levels:

for batch in dataset.shuffled(batch_size=32):
    # Samples are shuffled
    process(batch)

# Control shuffle buffer sizes
for batch in dataset.shuffled(
    buffer_shards=100,    # Shards to buffer (default: 100)
    buffer_samples=10000, # Samples to buffer (default: 10,000)
    batch_size=32,
):
    process(batch)

Tip

Larger buffer sizes increase randomness but use more memory. For training, buffer_samples=10000 is usually a good balance.
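As a rough rule of thumb, the shuffle buffer's memory footprint is about buffer_samples times the average serialized sample size. A quick back-of-envelope helper (a hypothetical sketch, not part of atdata):

```python
def shuffle_buffer_memory_mb(buffer_samples: int, avg_sample_bytes: int) -> float:
    """Rough upper bound on shuffle-buffer memory, in megabytes."""
    return buffer_samples * avg_sample_bytes / 1e6

# A 224x224x3 float32 image is ~602 KB, so the default buffer of
# 10,000 samples would hold roughly 6 GB; shrink buffer_samples
# (or the samples themselves) if that exceeds available memory.
print(shuffle_buffer_memory_mb(10_000, 224 * 224 * 3 * 4))
```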

SampleBatch

When iterating with a batch_size, each iteration yields a SampleBatch with automatic attribute aggregation.

@atdata.packable
class Sample:
    features: NDArray  # shape (256,)
    label: str
    score: float

for batch in dataset.ordered(batch_size=16):
    # NDArray fields are stacked with a batch dimension
    features = batch.features  # numpy array (16, 256)

    # Other fields become lists
    labels = batch.label       # list of 16 strings
    scores = batch.score       # list of 16 floats

Results are cached, so accessing the same attribute multiple times is efficient.
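Conceptually, batching groups each field's values across samples; a simplified pure-Python sketch of that aggregation (not atdata's actual implementation, which stacks NDArray fields into real arrays):

```python
def collate(samples: list[dict]) -> dict:
    """Group a list of per-sample field dicts into one dict of per-field sequences.

    In atdata, array-valued fields would be stacked (e.g. with np.stack) to
    gain a leading batch dimension; this sketch keeps the list-of-rows form.
    """
    batch = {}
    for field in samples[0]:
        batch[field] = [s[field] for s in samples]
    return batch

samples = [
    {"features": [0.1, 0.2], "label": "cat", "score": 0.9},
    {"features": [0.3, 0.4], "label": "dog", "score": 0.8},
]
batch = collate(samples)
print(batch["label"])     # ['cat', 'dog']
print(batch["features"])  # [[0.1, 0.2], [0.3, 0.4]] -- shape (2, 2) once stacked
```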

Type Transformations with Lenses

View a dataset through a different sample type using registered lenses:

@atdata.packable
class SimplifiedSample:
    label: str

@atdata.lens
def simplify(src: ImageSample) -> SimplifiedSample:
    return SimplifiedSample(label=src.label)

# Transform dataset to different type
simple_ds = dataset.as_type(SimplifiedSample)

for batch in simple_ds.ordered(batch_size=16):
    print(batch.label)  # Only label field available

See Lenses for details on defining transformations.

Dataset Properties

Shard List

Get the list of individual tar files:

dataset = atdata.Dataset[Sample]("data-{000000..000009}.tar")
shards = dataset.shard_list
# ['data-000000.tar', 'data-000001.tar', ..., 'data-000009.tar']
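Brace notation expands the same way shell-style ranges do, preserving zero padding. A minimal sketch of the expansion (a hypothetical helper, not atdata's or WebDataset's parser):

```python
import re

def expand_braces(url: str) -> list[str]:
    """Expand a single {start..end} numeric range, keeping zero padding."""
    m = re.search(r"\{(\d+)\.\.(\d+)\}", url)
    if m is None:
        return [url]  # no range present; return the URL unchanged
    start, end = m.group(1), m.group(2)
    width = len(start)
    return [
        url[:m.start()] + str(i).zfill(width) + url[m.end():]
        for i in range(int(start), int(end) + 1)
    ]

print(expand_braces("data-{000000..000002}.tar"))
# ['data-000000.tar', 'data-000001.tar', 'data-000002.tar']
```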

Metadata

Datasets can fetch associated metadata from a URL:

dataset = atdata.Dataset[Sample](
    "data-{000000..000009}.tar",
    metadata_url="https://example.com/metadata.msgpack"
)

# Fetched and cached on first access
metadata = dataset.metadata  # dict or None

Writing Datasets

Use WebDataset’s TarWriter or ShardWriter to create datasets:

import webdataset as wds
import numpy as np

samples = [
    ImageSample(image=np.random.rand(224, 224, 3).astype(np.float32), label="cat")
    for _ in range(100)
]

# Single tar file
with wds.writer.TarWriter("data-000000.tar") as sink:
    for i, sample in enumerate(samples):
        sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})

# Multiple shards with automatic splitting
with wds.writer.ShardWriter("data-%06d.tar", maxcount=1000) as sink:
    for i, sample in enumerate(samples):
        sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})
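ShardWriter starts a new tar file every maxcount samples, so the shard a sample lands in follows directly from its index. A quick sketch of that mapping (hypothetical helper, mirroring the data-%06d.tar pattern above):

```python
def shard_for_sample(index: int, maxcount: int) -> str:
    """Shard file name for a sample when splitting every `maxcount` samples."""
    return f"data-{index // maxcount:06d}.tar"

# 100 samples with maxcount=1000 all fit in the first shard:
print(shard_for_sample(0, 1000))     # data-000000.tar
print(shard_for_sample(999, 1000))   # data-000000.tar
print(shard_for_sample(1000, 1000))  # data-000001.tar
```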

Parquet Export

Export dataset contents to parquet format:

# Export entire dataset
dataset.to_parquet("output.parquet")

# Export with custom field mapping
def extract_fields(sample):
    return {"label": sample.label, "score": sample.score}

dataset.to_parquet("output.parquet", sample_map=extract_fields)

# Export in segments
dataset.to_parquet("output.parquet", maxcount=10000)
# Creates output-000000.parquet, output-000001.parquet, etc.
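With maxcount set, the number of segments is ceil(total_samples / maxcount); a small sketch of the resulting file names (hypothetical helper, following the output-%06d.parquet pattern above):

```python
import math

def segment_names(stem: str, total: int, maxcount: int) -> list[str]:
    """Segment file names produced when exporting `total` samples."""
    n_segments = math.ceil(total / maxcount)
    return [f"{stem}-{i:06d}.parquet" for i in range(n_segments)]

print(segment_names("output", 25_000, 10_000))
# ['output-000000.parquet', 'output-000001.parquet', 'output-000002.parquet']
```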

URL Formats

When using string URLs (via URLSource), WebDataset supports various formats:

Format        Example
Local files   ./data/file.tar, /absolute/path/file-{000000..000009}.tar
HTTP/HTTPS    https://example.com/data-{000000..000009}.tar
Google Cloud  gs://bucket/path/file.tar
For S3 with authentication, use S3Source instead of s3:// URLs.

Source and Sample Type

Source

Access the underlying DataSource:

dataset = atdata.Dataset[Sample]("data.tar")
source = dataset.source  # URLSource instance
print(source.shard_list)  # ['data.tar']

Sample Type

Get the type parameter used to create the dataset:

dataset = atdata.Dataset[ImageSample]("data.tar")
print(dataset.sample_type)  # <class 'ImageSample'>
print(dataset.batch_type)   # SampleBatch[ImageSample]