# Datasets

The `Dataset` class provides typed iteration over WebDataset tar files with automatic batching and lens transformations.

## Creating a Dataset

```python
import atdata
from numpy.typing import NDArray

@atdata.packable
class ImageSample:
    image: NDArray
    label: str

# Single shard (string URL - most common)
dataset = atdata.Dataset[ImageSample]("data-000000.tar")

# Multiple shards with brace notation
dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")
```

The type parameter `[ImageSample]` specifies what sample type the dataset contains. This enables type-safe iteration and automatic deserialization.
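To make the subscription mechanics concrete, here is a minimal sketch (not atdata's actual implementation) of how a generic class can capture a `[ImageSample]`-style type parameter via `__class_getitem__` so it is available later for deserialization. The `TypedDataset` name and toy `ImageSample` are illustrative only.

```python
from dataclasses import dataclass

class TypedDataset:
    """Toy illustration: Dataset[T]-style subscription that remembers T."""
    sample_type = None

    def __class_getitem__(cls, sample_type):
        # Return a subclass that records the requested sample type
        return type(
            f"TypedDataset[{sample_type.__name__}]",
            (cls,),
            {"sample_type": sample_type},
        )

@dataclass
class ImageSample:
    label: str

ds_cls = TypedDataset[ImageSample]
print(ds_cls.sample_type)  # <class '__main__.ImageSample'>
```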
## Data Sources

Datasets can be created from different data sources using the `DataSource` protocol.

### URL Source (default)

When you pass a string to `Dataset`, it is automatically wrapped in a `URLSource`:

```python
# These are equivalent:
dataset = atdata.Dataset[ImageSample]("data-{000000..000009}.tar")
dataset = atdata.Dataset[ImageSample](atdata.URLSource("data-{000000..000009}.tar"))
```

### S3 Source
For private S3 buckets or S3-compatible storage (Cloudflare R2, MinIO), use `S3Source`:

```python
# From explicit credentials
source = atdata.S3Source(
    bucket="my-bucket",
    keys=["data-000000.tar", "data-000001.tar"],
    endpoint="https://my-r2-account.r2.cloudflarestorage.com",
    access_key="AKID...",
    secret_key="SECRET...",
)
dataset = atdata.Dataset[ImageSample](source)

# From S3 URLs
source = atdata.S3Source.from_urls([
    "s3://my-bucket/data-000000.tar",
    "s3://my-bucket/data-000001.tar",
])
dataset = atdata.Dataset[ImageSample](source)
```

`S3Source` uses boto3 for streaming, enabling authentication with private buckets. For public S3 URLs, a string URL with `URLSource` works directly.
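As a rough sketch of the bookkeeping a `from_urls`-style constructor needs, the helper below splits `s3://` URLs into a bucket name plus object keys using only the standard library. This is an illustration of the idea, not atdata's actual code; the `parse_s3_urls` name is hypothetical.

```python
from urllib.parse import urlparse

def parse_s3_urls(urls):
    """Split s3:// URLs into one bucket name and a list of object keys."""
    buckets, keys = set(), []
    for url in urls:
        parsed = urlparse(url)
        if parsed.scheme != "s3":
            raise ValueError(f"not an s3:// URL: {url}")
        buckets.add(parsed.netloc)          # bucket is the netloc part
        keys.append(parsed.path.lstrip("/"))  # key is the path, sans slash
    if len(buckets) != 1:
        raise ValueError("all URLs must share a single bucket")
    return buckets.pop(), keys

bucket, keys = parse_s3_urls([
    "s3://my-bucket/data-000000.tar",
    "s3://my-bucket/data-000001.tar",
])
print(bucket, keys)  # my-bucket ['data-000000.tar', 'data-000001.tar']
```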
## Iteration Modes

### Ordered Iteration

Iterate through samples in their original order:

```python
# With batching (default batch_size=1)
for batch in dataset.ordered(batch_size=32):
    images = batch.image  # numpy array (32, H, W, C)
    labels = batch.label  # list of 32 strings

# Without batching (raw samples)
for sample in dataset.ordered(batch_size=None):
    print(sample.label)
```

### Shuffled Iteration
Iterate with randomized order at both the shard and sample level:

```python
for batch in dataset.shuffled(batch_size=32):
    # Samples are shuffled
    process(batch)

# Control shuffle buffer sizes
for batch in dataset.shuffled(
    buffer_shards=100,     # Shards to buffer (default: 100)
    buffer_samples=10000,  # Samples to buffer (default: 10,000)
    batch_size=32,
):
    process(batch)
```

Larger buffer sizes increase randomness but use more memory. For training, `buffer_samples=10000` is usually a good balance.
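The memory/randomness trade-off comes from how streaming shuffles work: only `buffer_size` items are held at once, and each emitted item is drawn randomly from that window. A minimal generic sketch of the technique (not atdata's implementation):

```python
import random

def buffered_shuffle(iterable, buffer_size, seed=0):
    """Streaming shuffle: hold up to buffer_size items; once full,
    emit a uniformly random buffered item for each new arrival."""
    rng = random.Random(seed)
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) >= buffer_size:
            idx = rng.randrange(len(buf))
            buf[idx], buf[-1] = buf[-1], buf[idx]  # move pick to the end
            yield buf.pop()
    rng.shuffle(buf)   # drain the remainder in random order
    yield from buf

print(list(buffered_shuffle(range(10), buffer_size=4)))
```

With `buffer_size` equal to the dataset length this degenerates to a full shuffle; with `buffer_size=1` it is a no-op, which is why larger buffers give better randomness.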
## SampleBatch

When iterating with a `batch_size`, each iteration yields a `SampleBatch` with automatic attribute aggregation:

```python
@atdata.packable
class Sample:
    features: NDArray  # shape (256,)
    label: str
    score: float

for batch in dataset.ordered(batch_size=16):
    # NDArray fields are stacked with a batch dimension
    features = batch.features  # numpy array (16, 256)

    # Other fields become lists
    labels = batch.label  # list of 16 strings
    scores = batch.score  # list of 16 floats
```

Results are cached, so accessing the same attribute multiple times is efficient.
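The aggregation rule described above (stack array fields, list everything else) can be sketched as a plain collate function; `aggregate` and the toy `Sample` class here are illustrative, not atdata's API:

```python
import numpy as np

def aggregate(samples):
    """Collate samples: stack ndarray fields on a new batch axis,
    collect all other fields into plain lists."""
    out = {}
    for name in vars(samples[0]):
        values = [getattr(s, name) for s in samples]
        if isinstance(values[0], np.ndarray):
            out[name] = np.stack(values)  # (batch, *field_shape)
        else:
            out[name] = values
    return out

class Sample:
    def __init__(self, features, label):
        self.features, self.label = features, label

batch = aggregate([
    Sample(np.zeros(256, dtype=np.float32), "cat"),
    Sample(np.ones(256, dtype=np.float32), "dog"),
])
print(batch["features"].shape)  # (2, 256)
print(batch["label"])           # ['cat', 'dog']
```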
## Type Transformations with Lenses

View a dataset through a different sample type using registered lenses:

```python
@atdata.packable
class SimplifiedSample:
    label: str

@atdata.lens
def simplify(src: ImageSample) -> SimplifiedSample:
    return SimplifiedSample(label=src.label)

# Transform dataset to a different type
simple_ds = dataset.as_type(SimplifiedSample)
for batch in simple_ds.ordered(batch_size=16):
    print(batch.label)  # Only the label field is available
```

See Lenses for details on defining transformations.
## Dataset Properties

### Metadata

Datasets can have associated metadata fetched from a URL:

```python
dataset = atdata.Dataset[Sample](
    "data-{000000..000009}.tar",
    metadata_url="https://example.com/metadata.msgpack",
)

# Fetched and cached on first access
metadata = dataset.metadata  # dict or None
```

## Writing Datasets
Use WebDataset's `TarWriter` or `ShardWriter` to create datasets:

```python
import webdataset as wds
import numpy as np

samples = [
    ImageSample(image=np.random.rand(224, 224, 3).astype(np.float32), label="cat")
    for _ in range(100)
]

# Single tar file
with wds.writer.TarWriter("data-000000.tar") as sink:
    for i, sample in enumerate(samples):
        sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})

# Multiple shards with automatic splitting
with wds.writer.ShardWriter("data-%06d.tar", maxcount=1000) as sink:
    for i, sample in enumerate(samples):
        sink.write({**sample.as_wds, "__key__": f"sample_{i:06d}"})
```

## Parquet Export
Export dataset contents to Parquet format:

```python
# Export the entire dataset
dataset.to_parquet("output.parquet")

# Export with a custom field mapping
def extract_fields(sample):
    return {"label": sample.label, "score": sample.confidence}

dataset.to_parquet("output.parquet", sample_map=extract_fields)

# Export in segments
dataset.to_parquet("output.parquet", maxcount=10000)
# Creates output-000000.parquet, output-000001.parquet, etc.
```

## URL Formats
When using string URLs (via `URLSource`), WebDataset supports several formats:

| Format | Example |
|---|---|
| Local files | `./data/file.tar`, `/absolute/path/file-{000000..000009}.tar` |
| HTTP/HTTPS | `https://example.com/data-{000000..000009}.tar` |
| Google Cloud | `gs://bucket/path/file.tar` |

For S3 with authentication, use `S3Source` instead of `s3://` URLs.
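The `{000000..000009}` shard notation used throughout these examples expands to one URL per number, preserving zero padding. A minimal single-range version of that expansion (WebDataset itself uses the more general `braceexpand` package; `expand_braces` here is just an illustration):

```python
import re

def expand_braces(url):
    """Expand one {lo..hi} numeric range in a URL, keeping zero padding."""
    m = re.search(r"\{(\d+)\.\.(\d+)\}", url)
    if m is None:
        return [url]  # no range: the URL names a single shard
    lo, hi = m.group(1), m.group(2)
    width = len(lo)  # '000000' -> pad expanded numbers to 6 digits
    return [
        url[:m.start()] + str(i).zfill(width) + url[m.end():]
        for i in range(int(lo), int(hi) + 1)
    ]

print(expand_braces("data-{000000..000002}.tar"))
# ['data-000000.tar', 'data-000001.tar', 'data-000002.tar']
```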
## Dataset Properties

### Source

Access the underlying `DataSource`:

```python
dataset = atdata.Dataset[Sample]("data.tar")
source = dataset.source   # URLSource instance
print(source.shard_list)  # ['data.tar']
```

### Sample Type

Get the type parameter used to create the dataset:

```python
dataset = atdata.Dataset[ImageSample]("data.tar")
print(dataset.sample_type)  # <class 'ImageSample'>
print(dataset.batch_type)   # SampleBatch[ImageSample]
```