Add DatasetDict.to_pandas #5312

Closed · wants to merge 6 commits
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.mdx
```diff
@@ -143,6 +143,7 @@ It also has dataset transform methods like map or filter, to process all the splits
 - from_parquet
 - from_text
 - prepare_for_task
+- to_pandas

 <a id='package_reference_features'></a>

```
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
````diff
@@ -4194,7 +4194,7 @@ def to_pandas(
         Example:

         ```py
-        >>> ds.to_pandas()
+        >>> df = ds.to_pandas()
         ```
         """
         if not batched:
````
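For reference, the corrected doctest mirrors ordinary usage, where the returned DataFrame is bound to a name; a minimal standalone example (the toy data here is ours, not from the PR):

```py
from datasets import Dataset

ds = Dataset.from_dict({"a": [1, 2, 3]})
df = ds.to_pandas()   # returns a pandas.DataFrame
print(df.shape)       # (3, 1)
```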
65 changes: 63 additions & 2 deletions src/datasets/dataset_dict.py
```diff
@@ -6,11 +6,13 @@
 import warnings
 from io import BytesIO
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

 import fsspec
 import numpy as np
+import pandas as pd
 from huggingface_hub import HfApi
+from typing_extensions import Literal

 from datasets.utils.metadata import DatasetMetadata

@@ -36,7 +38,11 @@
 logger = logging.get_logger(__name__)


-class DatasetDict(dict):
+class SplitsError(ValueError):
+    pass
+
+
+class DatasetDict(Dict[str, Dataset]):
     """A dictionary (dict of str: datasets.Dataset) with dataset transform methods (map, filter, etc.)"""

     def _check_values_type(self):

@@ -1417,6 +1423,61 @@ def push_to_hub(
            revision=branch,
        )

+    def to_pandas(
+        self,
+        splits: Optional[Union[Literal["all"], List[str]]] = None,
```
**Member:**

We could also add our `Split` to the type hint:

```py
df_all = dataset_dict.to_pandas(splits=Split.ALL)
```
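If it helps weigh this suggestion: to the best of our knowledge `Split.ALL` stringifies to `"all"`, so accepting it would mostly be a matter of coercing the argument; a quick check, assuming current `datasets` behavior:

```py
from datasets import Split

# NamedSplitAll is expected to stringify to "all" (assumption based on
# datasets.splits); if so, Split.ALL could reuse the existing `== "all"` branch.
print(str(Split.ALL))  # expected output: "all"
```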

**Member:**

Should we also allow a plain `str`?

```py
df_test = dataset_dict.to_pandas(splits="test")
```

**Member (author):**

If one wants to choose one split, they can already do `dataset_dict["test"].to_pandas()`; I don't think that introducing `splits="test"` would make it much easier.

Also, since we don't fully support the `Split` API (e.g. doing `"train+test[:20]"`), I wouldn't necessarily add `Split` to the type hint.
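To make the trade-off concrete, here are the two styles side by side (`splits="test"` is the reviewer's hypothetical; only the indexing form exists in this PR):

```py
# Works with this PR (and before it): index the split, then convert.
df_test = dataset_dict["test"].to_pandas()

# Reviewer's hypothetical shorthand, NOT implemented by this PR:
# df_test = dataset_dict.to_pandas(splits="test")
```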

````diff
+        batch_size: Optional[int] = None,
+        batched: bool = False,
+    ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
+        """Returns the dataset as a :class:`pandas.DataFrame`. Can also return a generator for large datasets.
+
+        You must specify which splits to convert if the dataset is made of multiple splits.
+
+        Args:
+            splits (:obj:`Union[Literal["all"], List[str]]`, optional): List of splits to convert to a DataFrame.
+                You don't need to specify the splits if there's only one.
+                Use splits="all" to convert all the splits (they will be converted in the order of the dictionary).
+            batched (:obj:`bool`): Set to :obj:`True` to return a generator that yields the dataset as batches
+                of ``batch_size`` rows. Defaults to :obj:`False` (returns the whole dataset at once).
+            batch_size (:obj:`int`, optional): The size (number of rows) of the batches if ``batched`` is `True`.
+                Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+
+        Returns:
+            `pandas.DataFrame` or `Iterator[pandas.DataFrame]`
+
+        Example:
+
+        If the dataset has one split:
+        ```py
+        >>> df = dataset_dict.to_pandas()
+        ```
+
+        If the dataset has multiple splits:
+        ```py
+        >>> df_train = dataset_dict["train"].to_pandas()
+        >>> df_all = dataset_dict.to_pandas(splits="all")
+        >>> df_train_test = dataset_dict.to_pandas(splits=["train", "test"])
+        ```
+        """
+        self._check_values_type()
+        self._check_values_features()
+        if splits is None and len(self) > 1:
+            raise SplitsError(
+                "Failed to convert to pandas: please choose which splits to convert. "
+                f"Available splits: {list(self)}. For example:"
+                '\n df = ds["train"].to_pandas()'
+                '\n df = ds.to_pandas(splits=["train", "test"])'
+                '\n df = ds.to_pandas(splits="all")'
+            )
````

**Contributor** (on the `raise SplitsError(` line):

Maybe invent a more specific name for this type of error? Something like `SplitsNotSpecifiedError` / `SplitsNotProvidedError` (subclassing `SplitsError`)?
```diff
+        splits = splits if splits is not None and splits != "all" else list(self)
+        bad_splits = list(set(splits) - set(self))
+        if bad_splits:
+            raise ValueError(f"Can't convert those splits to pandas: {bad_splits}. Available splits: {list(self)}.")
```

**Contributor** (on the `raise ValueError(` line):

Maybe raise a custom error here too, to be aligned with the `UnexpectedSplits` exception in `info_utils.py`:

```py
class UnexpectedSplits(SplitsVerificationException):
```

(subclassing the `SplitsError` defined above?)
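Pulling both suggestions together, the reviewers seem to be sketching a small hierarchy like the following (the subclass names are their proposals, not code in this PR):

```py
class SplitsError(ValueError):
    """Base class for split-related conversion errors (defined in this PR)."""

class SplitsNotSpecifiedError(SplitsError):
    """Proposed: raised when a multi-split DatasetDict is converted without choosing splits."""

class UnexpectedSplitsError(SplitsError):
    """Proposed: raised when requested splits don't exist, mirroring UnexpectedSplits in info_utils.py."""
```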

```diff
+        if batched:
+            return (df for split in splits for df in self[split].to_pandas(batch_size=batch_size, batched=batched))
+        else:
+            return pd.concat([self[split].to_pandas() for split in splits])


 class IterableDatasetDict(dict):
     def with_format(
```
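A minimal usage sketch of the new method (this assumes a checkout of this PR's branch, since `DatasetDict.to_pandas` is exactly what is being added here):

```py
from datasets import Dataset
from datasets.dataset_dict import DatasetDict

dsets = DatasetDict(
    {
        "train": Dataset.from_dict({"x": [1, 2]}),
        "test": Dataset.from_dict({"x": [3]}),
    }
)

df_all = dsets.to_pandas(splits="all")      # splits concatenated in dict order
df_test = dsets.to_pandas(splits=["test"])  # only the requested splits
print(len(df_all), len(df_test))            # 3 1
```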
45 changes: 44 additions & 1 deletion tests/test_dataset_dict.py
```diff
@@ -8,7 +8,7 @@

 from datasets import load_from_disk
 from datasets.arrow_dataset import Dataset
-from datasets.dataset_dict import DatasetDict
+from datasets.dataset_dict import DatasetDict, SplitsError
 from datasets.features import ClassLabel, Features, Sequence, Value
 from datasets.splits import NamedSplit

```
```diff
@@ -686,3 +686,46 @@ def test_dummy_dataset_serialize_s3(s3, dataset, s3_test_bucket_name):
     assert [len(dset) for dset in dsets.values()] == lengths
     assert dsets["train"].column_names == column_names
     assert dsets["test"].column_names == column_names
+
+
+def test_datasetdict_to_pandas():
+    dsets = DatasetDict(
+        {
+            "train": Dataset.from_dict({"foo": ["hello", "there"], "bar": [0, 1]}),
+        }
+    )
+    df = dsets.to_pandas()
+    assert df.shape == (2, 2)
+    assert list(df["foo"]) == ["hello", "there"]
+    assert list(df["bar"]) == [0, 1]
+
+    # multiple splits
+    dsets = DatasetDict(
+        {
+            "train": Dataset.from_dict({"foo": ["hello", "there"], "bar": [0, 1]}),
+            "test": Dataset.from_dict({"foo": ["general", "kenobi"], "bar": [2, 3]}),
+        }
+    )
+    with pytest.raises(SplitsError):
+        df = dsets.to_pandas()
+    df = dsets.to_pandas(splits=["train", "test"])
+    assert df.shape == (4, 2)
+    assert list(df["foo"]) == ["hello", "there", "general", "kenobi"]
+    assert list(df["bar"]) == [0, 1, 2, 3]
+    df = dsets.to_pandas(splits="all")
+    assert df.shape == (4, 2)
+    assert list(df["foo"]) == ["hello", "there", "general", "kenobi"]
+    assert list(df["bar"]) == [0, 1, 2, 3]
+
+    # batched
+    dsets = DatasetDict(
+        {
+            "train": Dataset.from_dict({"foo": range(42)}),
+        }
+    )
+    for i, df in enumerate(dsets.to_pandas(batched=True, batch_size=10)):
+        if i == 4:  # last batch
+            assert df.shape == (2, 1)
+        else:  # batch size of 10
+            assert df.shape == (10, 1)
+    assert i == 4  # 5 batches in total (i is zero-based)
```
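The batch arithmetic in that last test: 42 rows with `batch_size=10` gives five batches, four full ones plus a final 2-row batch, so the loop index `i` ends at 4. Outside the test suite, the same iterator can be reassembled into the full frame; a sketch, again assuming this PR's branch is installed:

```py
import pandas as pd
from datasets import Dataset
from datasets.dataset_dict import DatasetDict

dsets = DatasetDict({"train": Dataset.from_dict({"foo": list(range(42))})})
parts = list(dsets.to_pandas(batched=True, batch_size=10))  # 5 DataFrames
whole = pd.concat(parts, ignore_index=True)
assert len(whole) == 42  # 4 batches of 10 rows + 1 batch of 2 rows
```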