import hashlib
import json
import os
import random
import warnings
from typing import Callable

import pandas as pd
from toolz import curried

from .core import sanitize_dataframe
from .core import sanitize_geo_interface
from .deprecation import AltairDeprecationWarning
from .plugin_registry import PluginRegistry


# ==============================================================================
# Data transformer registry
# ==============================================================================
DataTransformerType = Callable


class DataTransformerRegistry(PluginRegistry[DataTransformerType]):
    _global_settings = {"consolidate_datasets": True}

    @property
    def consolidate_datasets(self):
        return self._global_settings["consolidate_datasets"]

    @consolidate_datasets.setter
    def consolidate_datasets(self, value):
        self._global_settings["consolidate_datasets"] = value
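
# Example (illustrative): dataset consolidation can be toggled through the
# registry's property. ``registry`` below is a hypothetical instance created
# only for demonstration; note that ``_global_settings`` is a class-level
# dict, so the setting is shared by all instances.
#
#     registry = DataTransformerRegistry()
#     registry.consolidate_datasets = False
#     assert registry.consolidate_datasets is False
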
# ==============================================================================
# Data model transformers
#
# A data model transformer is a pure function that takes a dict or DataFrame
# and returns a transformed version of a dict or DataFrame. The dict objects
# will be the Data portion of the VegaLite schema. The idea is that users can
# pipe a sequence of these data transformers together to prepare the data
# before it hits the renderer.
#
# In this version of Altair, renderers only deal with the dict form of a
# VegaLite spec, after the Data model has been put into a schema-compliant
# form.
#
# A data model transformer has the following type signature:
# DataModelType = Union[dict, pd.DataFrame]
# DataModelTransformerType = Callable[[DataModelType, KwArgs], DataModelType]
# ==============================================================================
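
# Example (illustrative): composing the transformers defined below with
# toolz's ``pipe``; ``df`` is an assumed pandas DataFrame.
#
#     from toolz.curried import pipe
#     pipe(df, limit_rows(max_rows=10000), to_values)
#     # -> {"values": [...]}  (or raises MaxRowsError)
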

class MaxRowsError(Exception):
    """Raised when a data model has too many rows."""


@curried.curry
def limit_rows(data, max_rows=5000):
    """Raise MaxRowsError if the data model has more than max_rows.

    If max_rows is None, then do not perform any check.
    """
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        # For geo data, count features when the object is a FeatureCollection;
        # otherwise fall back to the raw geo interface mapping.
        if data.__geo_interface__["type"] == "FeatureCollection":
            values = data.__geo_interface__["features"]
        else:
            values = data.__geo_interface__
    elif isinstance(data, pd.DataFrame):
        values = data
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
        else:
            # A dict without a "values" key has no rows to count.
            return data
    if max_rows is not None and len(values) > max_rows:
        raise MaxRowsError(
            "The number of rows in your dataset is greater "
            "than the maximum allowed ({}). "
            "For information on how to plot larger datasets "
            "in Altair, see the documentation.".format(max_rows)
        )
    return data
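
# Example (illustrative): because limit_rows is curried, it can be partially
# applied before receiving the data; ``df`` is an assumed pandas DataFrame.
#
#     limit_10 = limit_rows(max_rows=10)
#     limit_10(df)   # returns df, or raises MaxRowsError if len(df) > 10
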

@curried.curry
def sample(data, n=None, frac=None):
    """Reduce the size of the data model by sampling without replacement."""
    check_data_type(data)
    if isinstance(data, pd.DataFrame):
        return data.sample(n=n, frac=frac)
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
            # Derive the sample size from frac when n is not given.
            n = n if n else int(frac * len(values))
            values = random.sample(values, n)
            return {"values": values}
        else:
            # A dict without a "values" key cannot be sampled; pass it through.
            return data
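
# Example (illustrative): sampling a values-style dict down to 2 rows.
#
#     sample({"values": [{"x": 1}, {"x": 2}, {"x": 3}]}, n=2)
#     # -> {"values": [...two randomly chosen rows...]}
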

@curried.curry
def to_json(
    data,
    prefix="altair-data",
    extension="json",
    filename="{prefix}-{hash}.{extension}",
    urlpath="",
):
    """
    Write the data model to a .json file and return a url-based data model.
    """
    data_json = _data_to_json_string(data)
    data_hash = _compute_data_hash(data_json)
    # Hashing the serialized data yields a filename that is stable across reruns.
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    with open(filename, "w") as f:
        f.write(data_json)
    return {"url": os.path.join(urlpath, filename), "format": {"type": "json"}}
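
# Example (illustrative): the returned dict is a url-based data model;
# ``df`` is an assumed pandas DataFrame.
#
#     to_json(df)
#     # -> {"url": "altair-data-<md5>.json", "format": {"type": "json"}}
#     # and writes the JSON file into the current working directory.
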

@curried.curry
def to_csv(
    data,
    prefix="altair-data",
    extension="csv",
    filename="{prefix}-{hash}.{extension}",
    urlpath="",
):
    """Write the data model to a .csv file and return a url-based data model."""
    data_csv = _data_to_csv_string(data)
    data_hash = _compute_data_hash(data_csv)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    with open(filename, "w") as f:
        f.write(data_csv)
    return {"url": os.path.join(urlpath, filename), "format": {"type": "csv"}}

@curried.curry
def to_values(data):
    """Replace a DataFrame by a data model with values."""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        # Sanitize the DataFrame first so the geo interface it exposes is clean.
        if isinstance(data, pd.DataFrame):
            data = sanitize_dataframe(data)
        data = sanitize_geo_interface(data.__geo_interface__)
        return {"values": data}
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return {"values": data.to_dict(orient="records")}
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        return data
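
# Example (illustrative): converting a DataFrame to an inline values model;
# ``df`` is an assumed pandas DataFrame with columns "x" and "y".
#
#     to_values(df)
#     # -> {"values": [{"x": ..., "y": ...}, ...]}
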

def check_data_type(data):
    """Raise a TypeError if data is not a dict, a DataFrame, or an object
    exposing __geo_interface__.
    """
    if not isinstance(data, (dict, pd.DataFrame)) and not hasattr(
        data, "__geo_interface__"
    ):
        raise TypeError(
            "Expected dict, DataFrame or a __geo_interface__ attribute, got: {}".format(
                type(data)
            )
        )
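
# Example (illustrative):
#
#     check_data_type({"values": [{"x": 1}]})   # OK: dict is accepted
#     check_data_type([1, 2, 3])                # raises TypeError
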

# ==============================================================================
# Private utilities
# ==============================================================================


def _compute_data_hash(data_str):
    # md5 is used only to produce a short, stable filename, not for security.
    return hashlib.md5(data_str.encode()).hexdigest()

def _data_to_json_string(data):
    """Return a JSON string representation of the input data."""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        if isinstance(data, pd.DataFrame):
            data = sanitize_dataframe(data)
        data = sanitize_geo_interface(data.__geo_interface__)
        return json.dumps(data)
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return data.to_json(orient="records", double_precision=15)
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        return json.dumps(data["values"], sort_keys=True)
    else:
        raise NotImplementedError(
            "to_json only works with data expressed as a DataFrame or as a dict"
        )

def _data_to_csv_string(data):
    """Return a CSV string representation of the input data."""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        raise NotImplementedError(
            "to_csv does not work with data that "
            "contains the __geo_interface__ attribute"
        )
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return data.to_csv(index=False)
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
    else:
        raise NotImplementedError(
            "to_csv only works with data expressed as a DataFrame or as a dict"
        )

def pipe(data, *funcs):
    """
    Pipe a value through a sequence of functions

    Deprecated: use toolz.curried.pipe() instead.
    """
    warnings.warn(
        "alt.pipe() is deprecated, and will be removed in a future release. "
        "Use toolz.curried.pipe() instead.",
        AltairDeprecationWarning,
    )
    return curried.pipe(data, *funcs)
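
# Example (illustrative): migrating off the deprecated wrapper; ``df`` is an
# assumed pandas DataFrame.
#
#     from toolz import curried
#     curried.pipe(df, limit_rows, to_values)   # instead of pipe(df, ...)
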

def curry(*args, **kwargs):
    """Curry a callable function

    Deprecated: use toolz.curried.curry() instead.
    """
    warnings.warn(
        "alt.curry() is deprecated, and will be removed in a future release. "
        "Use toolz.curried.curry() instead.",
        AltairDeprecationWarning,
    )
    return curried.curry(*args, **kwargs)