import json
import os
import random
import hashlib
import warnings

import pandas as pd
from toolz import curried
from typing import Callable

from .core import sanitize_dataframe
from .core import sanitize_geo_interface
from .deprecation import AltairDeprecationWarning
from .plugin_registry import PluginRegistry


# ==============================================================================
# Data transformer registry
# ==============================================================================
DataTransformerType = Callable


class DataTransformerRegistry(PluginRegistry[DataTransformerType]):
    _global_settings = {"consolidate_datasets": True}

    @property
    def consolidate_datasets(self):
        return self._global_settings["consolidate_datasets"]

    @consolidate_datasets.setter
    def consolidate_datasets(self, value):
        self._global_settings["consolidate_datasets"] = value
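
# Illustrative sketch (not executed on import): the registry's
# ``consolidate_datasets`` flag is a global toggle, so user code can flip it
# through the registry instance, e.g. via Altair's ``alt.data_transformers``:
#
#     import altair as alt
#     alt.data_transformers.consolidate_datasets = False  # disable consolidation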


# ==============================================================================
# Data model transformers
#
# A data model transformer is a pure function that takes a dict or DataFrame
# and returns a transformed version of a dict or DataFrame. The dict objects
# will be the Data portion of the VegaLite schema. The idea is that users can
# pipe a sequence of these data transformers together to prepare the data
# before it hits the renderer.
#
# In this version of Altair, renderers only deal with the dict form of a
# VegaLite spec, after the Data model has been put into a schema-compliant
# form.
#
# A data model transformer has the following type signature:
# DataModelType = Union[dict, pd.DataFrame]
# DataModelTransformerType = Callable[[DataModelType, KwArgs], DataModelType]
# ==============================================================================
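
# For example (an illustrative sketch, not executed on import), a pipeline of
# the transformers defined below might look like:
#
#     import pandas as pd
#     from toolz.curried import pipe
#
#     df = pd.DataFrame({"x": [1, 2, 3]})
#     spec_data = pipe(df, limit_rows(max_rows=1000), to_values)
#     # -> {"values": [{"x": 1}, {"x": 2}, {"x": 3}]}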


class MaxRowsError(Exception):
    """Raised when a data model has too many rows."""

    pass


@curried.curry
def limit_rows(data, max_rows=5000):
    """Raise MaxRowsError if the data model has more than max_rows.

    If max_rows is None, then do not perform any check.
    """
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        if data.__geo_interface__["type"] == "FeatureCollection":
            values = data.__geo_interface__["features"]
        else:
            values = data.__geo_interface__
    elif isinstance(data, pd.DataFrame):
        values = data
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
        else:
            return data
    if max_rows is not None and len(values) > max_rows:
        raise MaxRowsError(
            "The number of rows in your dataset is greater "
            "than the maximum allowed ({}). "
            "For information on how to plot larger datasets "
            "in Altair, see the documentation".format(max_rows)
        )
    return data
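
# Illustrative sketch (not executed on import): with the default of 5000 rows,
#
#     import pandas as pd
#     small = pd.DataFrame({"x": range(10)})
#     limit_rows(small)                  # returns `small` unchanged
#     limit_rows(small, max_rows=5)      # raises MaxRowsError
#     limit_rows(small, max_rows=None)   # skips the check entirely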


@curried.curry
def sample(data, n=None, frac=None):
    """Reduce the size of the data model by sampling without replacement."""
    check_data_type(data)
    if isinstance(data, pd.DataFrame):
        return data.sample(n=n, frac=frac)
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        values = data["values"]
        if n is None:
            # Fall back to frac; fail loudly rather than returning None.
            if frac is None:
                raise ValueError("sample requires either n or frac to be set")
            n = int(frac * len(values))
        values = random.sample(values, n)
        return {"values": values}
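
# Illustrative sketch (not executed on import):
#
#     import pandas as pd
#     df = pd.DataFrame({"x": range(100)})
#     sample(df, n=10)    # DataFrame with 10 randomly chosen rows
#     sample({"values": [{"x": i} for i in range(100)]}, frac=0.1)
#     # -> {"values": [...]} with 10 sampled records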


@curried.curry
def to_json(
    data,
    prefix="altair-data",
    extension="json",
    filename="{prefix}-{hash}.{extension}",
    urlpath="",
):
    """Write the data model to a .json file and return a URL-based data model."""
    data_json = _data_to_json_string(data)
    data_hash = _compute_data_hash(data_json)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    with open(filename, "w") as f:
        f.write(data_json)
    return {"url": os.path.join(urlpath, filename), "format": {"type": "json"}}


@curried.curry
def to_csv(
    data,
    prefix="altair-data",
    extension="csv",
    filename="{prefix}-{hash}.{extension}",
    urlpath="",
):
    """Write the data model to a .csv file and return a URL-based data model."""
    data_csv = _data_to_csv_string(data)
    data_hash = _compute_data_hash(data_csv)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    with open(filename, "w") as f:
        f.write(data_csv)
    return {"url": os.path.join(urlpath, filename), "format": {"type": "csv"}}
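
# Illustrative sketch (not executed on import): both writers name the file
# after a hash of its contents, so identical data reuses the same file.
#
#     import pandas as pd
#     df = pd.DataFrame({"x": [1, 2, 3]})
#     to_json(df)
#     # -> {"url": "altair-data-<md5>.json", "format": {"type": "json"}}
#     to_csv(df, urlpath="data")
#     # -> {"url": "data/altair-data-<md5>.csv", "format": {"type": "csv"}}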


@curried.curry
def to_values(data):
    """Replace a DataFrame by a data model with values."""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        if isinstance(data, pd.DataFrame):
            data = sanitize_dataframe(data)
        data = sanitize_geo_interface(data.__geo_interface__)
        return {"values": data}
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return {"values": data.to_dict(orient="records")}
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        return data
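
# Illustrative sketch (not executed on import):
#
#     import pandas as pd
#     to_values(pd.DataFrame({"x": [1, 2]}))
#     # -> {"values": [{"x": 1}, {"x": 2}]}
#     to_values({"values": [{"x": 1}]})   # dicts with "values" pass through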


def check_data_type(data):
    """Raise if the data is not a dict or DataFrame."""
    if not isinstance(data, (dict, pd.DataFrame)) and not hasattr(
        data, "__geo_interface__"
    ):
        raise TypeError(
            "Expected dict, DataFrame, or an object with a __geo_interface__ "
            "attribute, got: {}".format(type(data))
        )


# ==============================================================================
# Private utilities
# ==============================================================================
def _compute_data_hash(data_str):
    # MD5 of the serialized data; used to build stable, content-addressed filenames.
    return hashlib.md5(data_str.encode()).hexdigest()


def _data_to_json_string(data):
    """Return a JSON string representation of the input data."""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        if isinstance(data, pd.DataFrame):
            data = sanitize_dataframe(data)
        data = sanitize_geo_interface(data.__geo_interface__)
        return json.dumps(data)
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return data.to_json(orient="records", double_precision=15)
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        return json.dumps(data["values"], sort_keys=True)
    else:
        raise NotImplementedError(
            "to_json only works with data expressed as a DataFrame or as a dict"
        )


def _data_to_csv_string(data):
    """Return a CSV string representation of the input data."""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        raise NotImplementedError(
            "to_csv does not work with data that "
            "contains the __geo_interface__ attribute"
        )
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return data.to_csv(index=False)
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
    else:
        raise NotImplementedError(
            "to_csv only works with data expressed as a DataFrame or as a dict"
        )


def pipe(data, *funcs):
    """Pipe a value through a sequence of functions.

    Deprecated: use toolz.curried.pipe() instead.
    """
    warnings.warn(
        "alt.pipe() is deprecated, and will be removed in a future release. "
        "Use toolz.curried.pipe() instead.",
        AltairDeprecationWarning,
    )
    return curried.pipe(data, *funcs)
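
# Illustrative sketch (not executed on import): the deprecated helpers are thin
# wrappers, so the toolz equivalents are drop-in replacements, e.g.
#
#     from toolz.curried import pipe
#     pipe({"values": [{"x": 1}]}, limit_rows(max_rows=10), to_values)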


def curry(*args, **kwargs):
    """Curry a callable function.

    Deprecated: use toolz.curried.curry() instead.
    """
    warnings.warn(
        "alt.curry() is deprecated, and will be removed in a future release. "
        "Use toolz.curried.curry() instead.",
        AltairDeprecationWarning,
    )
    return curried.curry(*args, **kwargs)