first commit

Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions


@@ -0,0 +1,32 @@
from .core import (
infer_vegalite_type,
infer_encoding_types,
sanitize_dataframe,
parse_shorthand,
use_signature,
update_subtraits,
update_nested,
display_traceback,
SchemaBase,
Undefined,
)
from .html import spec_to_html
from .plugin_registry import PluginRegistry
from .deprecation import AltairDeprecationWarning
__all__ = (
"infer_vegalite_type",
"infer_encoding_types",
"sanitize_dataframe",
"spec_to_html",
"parse_shorthand",
"use_signature",
"update_subtraits",
"update_nested",
"display_traceback",
"AltairDeprecationWarning",
"SchemaBase",
"Undefined",
"PluginRegistry",
)
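
A quick sketch of how this re-export surface is meant to be used. The package path is an assumption (the diff view does not show file names); the column name in the shorthand is invented for illustration.

from altair.utils import parse_shorthand, Undefined  # assumed package path

print(parse_shorthand("mean(price):Q"))
# {'aggregate': 'mean', 'field': 'price', 'type': 'quantitative'}
print(Undefined)  # the UndefinedType sentinel, printed as 'Undefined'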


@@ -0,0 +1,731 @@
"""
Utility routines
"""
from collections.abc import Mapping
from copy import deepcopy
import json
import itertools
import re
import sys
import traceback
import warnings
import jsonschema
import pandas as pd
import numpy as np
from .schemapi import SchemaBase, Undefined
try:
from pandas.api.types import infer_dtype as _infer_dtype
except ImportError:
# Import for pandas < 0.20.0
from pandas.lib import infer_dtype as _infer_dtype
def infer_dtype(value):
"""Infer the dtype of the value.
This is a compatibility function for pandas infer_dtype,
with skipna=False regardless of the pandas version.
"""
if not hasattr(infer_dtype, "_supports_skipna"):
try:
_infer_dtype([1], skipna=False)
except TypeError:
# pandas < 0.21.0 doesn't support the skipna keyword
infer_dtype._supports_skipna = False
else:
infer_dtype._supports_skipna = True
if infer_dtype._supports_skipna:
return _infer_dtype(value, skipna=False)
else:
return _infer_dtype(value)
TYPECODE_MAP = {
"ordinal": "O",
"nominal": "N",
"quantitative": "Q",
"temporal": "T",
"geojson": "G",
}
INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()}
# aggregates from vega-lite version 4.6.0
AGGREGATES = [
"argmax",
"argmin",
"average",
"count",
"distinct",
"max",
"mean",
"median",
"min",
"missing",
"product",
"q1",
"q3",
"ci0",
"ci1",
"stderr",
"stdev",
"stdevp",
"sum",
"valid",
"values",
"variance",
"variancep",
]
# window aggregates from vega-lite version 4.6.0
WINDOW_AGGREGATES = [
"row_number",
"rank",
"dense_rank",
"percent_rank",
"cume_dist",
"ntile",
"lag",
"lead",
"first_value",
"last_value",
"nth_value",
]
# timeUnits from vega-lite version 4.17.0
TIMEUNITS = [
"year",
"quarter",
"month",
"week",
"day",
"dayofyear",
"date",
"hours",
"minutes",
"seconds",
"milliseconds",
"yearquarter",
"yearquartermonth",
"yearmonth",
"yearmonthdate",
"yearmonthdatehours",
"yearmonthdatehoursminutes",
"yearmonthdatehoursminutesseconds",
"yearweek",
"yearweekday",
"yearweekdayhours",
"yearweekdayhoursminutes",
"yearweekdayhoursminutesseconds",
"yeardayofyear",
"quartermonth",
"monthdate",
"monthdatehours",
"monthdatehoursminutes",
"monthdatehoursminutesseconds",
"weekday",
"weeksdayhours",
"weekdayhoursminutes",
"weekdayhoursminutesseconds",
"dayhours",
"dayhoursminutes",
"dayhoursminutesseconds",
"hoursminutes",
"hoursminutesseconds",
"minutesseconds",
"secondsmilliseconds",
"utcyear",
"utcquarter",
"utcmonth",
"utcweek",
"utcday",
"utcdayofyear",
"utcdate",
"utchours",
"utcminutes",
"utcseconds",
"utcmilliseconds",
"utcyearquarter",
"utcyearquartermonth",
"utcyearmonth",
"utcyearmonthdate",
"utcyearmonthdatehours",
"utcyearmonthdatehoursminutes",
"utcyearmonthdatehoursminutesseconds",
"utcyearweek",
"utcyearweekday",
"utcyearweekdayhours",
"utcyearweekdayhoursminutes",
"utcyearweekdayhoursminutesseconds",
"utcyeardayofyear",
"utcquartermonth",
"utcmonthdate",
"utcmonthdatehours",
"utcmonthdatehoursminutes",
"utcmonthdatehoursminutesseconds",
"utcweekday",
"utcweeksdayhours",
"utcweekdayhoursminutes",
"utcweekdayhoursminutesseconds",
"utcdayhours",
"utcdayhoursminutes",
"utcdayhoursminutesseconds",
"utchoursminutes",
"utchoursminutesseconds",
"utcminutesseconds",
"utcsecondsmilliseconds",
]
def infer_vegalite_type(data):
"""
From an array-like input, infer the correct vega typecode
('ordinal', 'nominal', 'quantitative', or 'temporal')
Parameters
----------
data: Numpy array or Pandas Series
"""
# Otherwise, infer based on the dtype of the input
typ = infer_dtype(data)
# TODO: Once this returns 'O', please update test_select_x and test_select_y in test_api.py
if typ in [
"floating",
"mixed-integer-float",
"integer",
"mixed-integer",
"complex",
]:
return "quantitative"
elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
return "nominal"
elif typ in [
"datetime",
"datetime64",
"timedelta",
"timedelta64",
"date",
"time",
"period",
]:
return "temporal"
else:
warnings.warn(
"I don't know how to infer vegalite type from '{}'. "
"Defaulting to nominal.".format(typ)
)
return "nominal"
def merge_props_geom(feat):
"""
Merge properties with geometry
* Overwrites 'type' and 'geometry' entries if existing
"""
geom = {k: feat[k] for k in ("type", "geometry")}
try:
feat["properties"].update(geom)
props_geom = feat["properties"]
except (AttributeError, KeyError):
# AttributeError when 'properties' equals None
# KeyError when 'properties' is non-existing
props_geom = geom
return props_geom
def sanitize_geo_interface(geo):
"""Santize a geo_interface to prepare it for serialization.
* Make a copy
* Convert type array or _Array to list
* Convert tuples to lists (using json.loads/dumps)
* Merge properties with geometry
"""
geo = deepcopy(geo)
# convert type _Array or array to list
for key in geo.keys():
if str(type(geo[key]).__name__).startswith(("_Array", "array")):
geo[key] = geo[key].tolist()
# convert (nested) tuples to lists
geo = json.loads(json.dumps(geo))
# sanitize features
if geo["type"] == "FeatureCollection":
geo = geo["features"]
if len(geo) > 0:
for idx, feat in enumerate(geo):
geo[idx] = merge_props_geom(feat)
elif geo["type"] == "Feature":
geo = merge_props_geom(geo)
else:
geo = {"type": "Feature", "geometry": geo}
return geo
def sanitize_dataframe(df): # noqa: C901
"""Sanitize a DataFrame to prepare it for serialization.
* Make a copy
* Convert RangeIndex columns to strings
* Raise ValueError if column names are not strings
* Raise ValueError if it has a hierarchical index.
* Convert categoricals to strings.
* Convert np.bool_ dtypes to Python bool objects
* Convert np.int dtypes to Python int objects
* Convert floats to objects and replace NaNs/infs with None.
* Convert DateTime dtypes into appropriate string representations
* Convert Nullable integers to objects and replace NaN with None
* Convert Nullable boolean to objects and replace NaN with None
* convert dedicated string column to objects and replace NaN with None
* Raise a ValueError for TimeDelta dtypes
"""
df = df.copy()
if isinstance(df.columns, pd.RangeIndex):
df.columns = df.columns.astype(str)
for col in df.columns:
if not isinstance(col, str):
raise ValueError(
"Dataframe contains invalid column name: {0!r}. "
"Column names must be strings".format(col)
)
if isinstance(df.index, pd.MultiIndex):
raise ValueError("Hierarchical indices not supported")
if isinstance(df.columns, pd.MultiIndex):
raise ValueError("Hierarchical indices not supported")
def to_list_if_array(val):
if isinstance(val, np.ndarray):
return val.tolist()
else:
return val
for col_name, dtype in df.dtypes.iteritems():
if str(dtype) == "category":
# XXXX: work around bug in to_json for categorical types
# https://github.com/pydata/pandas/issues/10778
col = df[col_name].astype(object)
df[col_name] = col.where(col.notnull(), None)
elif str(dtype) == "string":
# dedicated string datatype (since 1.0)
# https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type
col = df[col_name].astype(object)
df[col_name] = col.where(col.notnull(), None)
elif str(dtype) == "bool":
# convert numpy bools to objects; np.bool is not JSON serializable
df[col_name] = df[col_name].astype(object)
elif str(dtype) == "boolean":
# dedicated boolean datatype (since 1.0)
# https://pandas.io/docs/user_guide/boolean.html
col = df[col_name].astype(object)
df[col_name] = col.where(col.notnull(), None)
elif str(dtype).startswith("datetime"):
# Convert datetimes to strings. This needs to be a full ISO string
# with time, which is why we cannot use ``col.astype(str)``.
# This is because Javascript parses date-only times in UTC, but
# parses full ISO-8601 dates as local time, and dates in Vega and
# Vega-Lite are displayed in local time by default.
# (see https://github.com/altair-viz/altair/issues/1027)
df[col_name] = (
df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
)
elif str(dtype).startswith("timedelta"):
raise ValueError(
'Field "{col_name}" has type "{dtype}" which is '
"not supported by Altair. Please convert to "
"either a timestamp or a numerical value."
"".format(col_name=col_name, dtype=dtype)
)
elif str(dtype).startswith("geometry"):
# geopandas >=0.6.1 uses the dtype geometry. Continue here
# otherwise it will give an error on np.issubdtype(dtype, np.integer)
continue
elif str(dtype) in {
"Int8",
"Int16",
"Int32",
"Int64",
"UInt8",
"UInt16",
"UInt32",
"UInt64",
"Float32",
"Float64",
}: # nullable integer datatypes (since 0.24.0) and nullable float datatypes (since 1.2.0)
# https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support
col = df[col_name].astype(object)
df[col_name] = col.where(col.notnull(), None)
elif np.issubdtype(dtype, np.integer):
# convert integers to objects; np.int is not JSON serializable
df[col_name] = df[col_name].astype(object)
elif np.issubdtype(dtype, np.floating):
# For floats, convert to Python float: np.float is not JSON serializable
# Also convert NaN/inf values to null, as they are not JSON serializable
col = df[col_name]
bad_values = col.isnull() | np.isinf(col)
df[col_name] = col.astype(object).where(~bad_values, None)
elif dtype == object:
# Convert numpy arrays saved as objects to lists
# Arrays are not JSON serializable
col = df[col_name].apply(to_list_if_array, convert_dtype=False)
df[col_name] = col.where(col.notnull(), None)
return df
def parse_shorthand(
shorthand,
data=None,
parse_aggregates=True,
parse_window_ops=False,
parse_timeunits=True,
parse_types=True,
):
"""General tool to parse shorthand values
These are of the form:
- "col_name"
- "col_name:O"
- "average(col_name)"
- "average(col_name):O"
Optionally, a dataframe may be supplied, from which the type
will be inferred if not specified in the shorthand.
Parameters
----------
shorthand : dict or string
The shorthand representation to be parsed
data : DataFrame, optional
If specified and of type DataFrame, then use these values to infer the
column type if not provided by the shorthand.
parse_aggregates : boolean
If True (default), then parse aggregate functions within the shorthand.
parse_window_ops : boolean
If True then parse window operations within the shorthand (default:False)
parse_timeunits : boolean
If True (default), then parse timeUnits from within the shorthand
parse_types : boolean
If True (default), then parse typecodes within the shorthand
Returns
-------
attrs : dict
a dictionary of attributes extracted from the shorthand
Examples
--------
>>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'],
... 'bar': [1, 2, 3, 4]})
>>> parse_shorthand('name') == {'field': 'name'}
True
>>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'}
True
>>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'}
True
>>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'}
True
>>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'}
True
>>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'}
True
>>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'}
True
>>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'}
True
>>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'}
True
>>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'}
True
>>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'}
True
>>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'}
True
"""
if not shorthand:
return {}
valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP)
units = dict(
field="(?P<field>.*)",
type="(?P<type>{})".format("|".join(valid_typecodes)),
agg_count="(?P<aggregate>count)",
op_count="(?P<op>count)",
aggregate="(?P<aggregate>{})".format("|".join(AGGREGATES)),
window_op="(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
timeUnit="(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
)
patterns = []
if parse_aggregates:
patterns.extend([r"{agg_count}\(\)"])
patterns.extend([r"{aggregate}\({field}\)"])
if parse_window_ops:
patterns.extend([r"{op_count}\(\)"])
patterns.extend([r"{window_op}\({field}\)"])
if parse_timeunits:
patterns.extend([r"{timeUnit}\({field}\)"])
patterns.extend([r"{field}"])
if parse_types:
patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns)))
regexps = (
re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns
)
# find matches depending on valid fields passed
if isinstance(shorthand, dict):
attrs = shorthand
else:
attrs = next(
exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand)
)
# Handle short form of the type expression
if "type" in attrs:
attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"])
# counts are quantitative by default
if attrs == {"aggregate": "count"}:
attrs["type"] = "quantitative"
# times are temporal by default
if "timeUnit" in attrs and "type" not in attrs:
attrs["type"] = "temporal"
# if data is specified and type is not, infer type from data
if isinstance(data, pd.DataFrame) and "type" not in attrs:
if "field" in attrs and attrs["field"] in data.columns:
attrs["type"] = infer_vegalite_type(data[attrs["field"]])
return attrs
def use_signature(Obj):
"""Apply call signature and documentation of Obj to the decorated method"""
def decorate(f):
# call-signature of f is exposed via __wrapped__.
# we want it to mimic Obj.__init__
f.__wrapped__ = Obj.__init__
f._uses_signature = Obj
# Supplement the docstring of f with information from Obj
if Obj.__doc__:
doclines = Obj.__doc__.splitlines()
if f.__doc__:
doc = f.__doc__ + "\n".join(doclines[1:])
else:
doc = "\n".join(doclines)
try:
f.__doc__ = doc
except AttributeError:
# __doc__ is not modifiable for classes in Python < 3.3
pass
return f
return decorate
def update_subtraits(obj, attrs, **kwargs):
"""Recursively update sub-traits without overwriting other traits"""
# TODO: infer keywords from args
if not kwargs:
return obj
# obj can be a SchemaBase object or a dict
if obj is Undefined:
obj = dct = {}
elif isinstance(obj, SchemaBase):
dct = obj._kwds
else:
dct = obj
if isinstance(attrs, str):
attrs = (attrs,)
if len(attrs) == 0:
dct.update(kwargs)
else:
attr = attrs[0]
trait = dct.get(attr, Undefined)
if trait is Undefined:
trait = dct[attr] = {}
dct[attr] = update_subtraits(trait, attrs[1:], **kwargs)
return obj
def update_nested(original, update, copy=False):
"""Update nested dictionaries
Parameters
----------
original : dict
the original (nested) dictionary, which will be updated in-place
update : dict
the nested dictionary of updates
copy : bool, default False
if True, then copy the original dictionary rather than modifying it
Returns
-------
original : dict
a reference to the (modified) original dict
Examples
--------
>>> original = {'x': {'b': 2, 'c': 4}}
>>> update = {'x': {'b': 5, 'd': 6}, 'y': 40}
>>> update_nested(original, update) # doctest: +SKIP
{'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
>>> original # doctest: +SKIP
{'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
"""
if copy:
original = deepcopy(original)
for key, val in update.items():
if isinstance(val, Mapping):
orig_val = original.get(key, {})
if isinstance(orig_val, Mapping):
original[key] = update_nested(orig_val, val)
else:
original[key] = val
else:
original[key] = val
return original
def display_traceback(in_ipython=True):
exc_info = sys.exc_info()
if in_ipython:
from IPython.core.getipython import get_ipython
ip = get_ipython()
else:
ip = None
if ip is not None:
ip.showtraceback(exc_info)
else:
traceback.print_exception(*exc_info)
def infer_encoding_types(args, kwargs, channels):
"""Infer typed keyword arguments for args and kwargs
Parameters
----------
args : tuple
List of function args
kwargs : dict
Dict of function kwargs
channels : module
The module containing all altair encoding channel classes.
Returns
-------
kwargs : dict
All args and kwargs in a single dict, with keys and types
based on the channels mapping.
"""
# Construct a dictionary of channel type to encoding name
# TODO: cache this somehow?
channel_objs = (getattr(channels, name) for name in dir(channels))
channel_objs = (
c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
)
channel_to_name = {c: c._encoding_name for c in channel_objs}
name_to_channel = {}
for chan, name in channel_to_name.items():
chans = name_to_channel.setdefault(name, {})
if chan.__name__.endswith("Datum"):
key = "datum"
elif chan.__name__.endswith("Value"):
key = "value"
else:
key = "field"
chans[key] = chan
# First use the mapping to convert args to kwargs based on their types.
for arg in args:
if isinstance(arg, (list, tuple)) and len(arg) > 0:
type_ = type(arg[0])
else:
type_ = type(arg)
encoding = channel_to_name.get(type_, None)
if encoding is None:
raise NotImplementedError("positional of type {}" "".format(type_))
if encoding in kwargs:
raise ValueError("encoding {} specified twice.".format(encoding))
kwargs[encoding] = arg
def _wrap_in_channel_class(obj, encoding):
try:
condition = obj["condition"]
except (KeyError, TypeError):
pass
else:
if condition is not Undefined:
obj = obj.copy()
obj["condition"] = _wrap_in_channel_class(condition, encoding)
if isinstance(obj, SchemaBase):
return obj
if isinstance(obj, str):
obj = {"shorthand": obj}
if isinstance(obj, (list, tuple)):
return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]
if encoding not in name_to_channel:
warnings.warn("Unrecognized encoding channel '{}'".format(encoding))
return obj
classes = name_to_channel[encoding]
cls = classes["value"] if "value" in obj else classes["field"]
try:
# Don't force validation here; some objects won't be valid until
# they're created in the context of a chart.
return cls.from_dict(obj, validate=False)
except jsonschema.ValidationError:
# our attempts at finding the correct class have failed
return obj
return {
encoding: _wrap_in_channel_class(obj, encoding)
for encoding, obj in kwargs.items()
}
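
A minimal sketch exercising sanitize_dataframe and parse_shorthand from the module above, assuming it is importable as altair.utils.core; the column names and values are invented for illustration, and the expected results follow the docstrings above.

import numpy as np
import pandas as pd
from altair.utils.core import parse_shorthand, sanitize_dataframe  # assumed module path

df = pd.DataFrame({
    "price": [1.0, np.nan],
    "when": pd.to_datetime(["2020-01-01", "2020-01-02"]),
})
clean = sanitize_dataframe(df)
print(clean["price"].tolist())  # [1.0, None] -- NaN replaced with None
print(clean["when"].tolist())   # ISO-8601 strings, e.g. '2020-01-01T00:00:00'

print(parse_shorthand("mean(price)", df))
# {'aggregate': 'mean', 'field': 'price', 'type': 'quantitative'}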


@@ -0,0 +1,244 @@
import json
import os
import random
import hashlib
import warnings
import pandas as pd
from toolz import curried
from typing import Callable
from .core import sanitize_dataframe
from .core import sanitize_geo_interface
from .deprecation import AltairDeprecationWarning
from .plugin_registry import PluginRegistry
# ==============================================================================
# Data transformer registry
# ==============================================================================
DataTransformerType = Callable
class DataTransformerRegistry(PluginRegistry[DataTransformerType]):
_global_settings = {"consolidate_datasets": True}
@property
def consolidate_datasets(self):
return self._global_settings["consolidate_datasets"]
@consolidate_datasets.setter
def consolidate_datasets(self, value):
self._global_settings["consolidate_datasets"] = value
# ==============================================================================
# Data model transformers
#
# A data model transformer is a pure function that takes a dict or DataFrame
# and returns a transformed version of a dict or DataFrame. The dict objects
# will be the Data portion of the VegaLite schema. The idea is that the user can
# pipe a sequence of these data transformers together to prepare the data before
# it hits the renderer.
#
# In this version of Altair, renderers only deal with the dict form of a
# VegaLite spec, after the Data model has been put into a schema compliant
# form.
#
# A data model transformer has the following type signature:
# DataModelType = Union[dict, pd.DataFrame]
# DataModelTransformerType = Callable[[DataModelType, KwArgs], DataModelType]
# ==============================================================================
class MaxRowsError(Exception):
"""Raised when a data model has too many rows."""
pass
@curried.curry
def limit_rows(data, max_rows=5000):
"""Raise MaxRowsError if the data model has more than max_rows.
If max_rows is None, then do not perform any check.
"""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
if data.__geo_interface__["type"] == "FeatureCollection":
values = data.__geo_interface__["features"]
else:
values = data.__geo_interface__
elif isinstance(data, pd.DataFrame):
values = data
elif isinstance(data, dict):
if "values" in data:
values = data["values"]
else:
return data
if max_rows is not None and len(values) > max_rows:
raise MaxRowsError(
"The number of rows in your dataset is greater "
"than the maximum allowed ({}). "
"For information on how to plot larger datasets "
"in Altair, see the documentation".format(max_rows)
)
return data
@curried.curry
def sample(data, n=None, frac=None):
"""Reduce the size of the data model by sampling without replacement."""
check_data_type(data)
if isinstance(data, pd.DataFrame):
return data.sample(n=n, frac=frac)
elif isinstance(data, dict):
if "values" in data:
values = data["values"]
n = n if n else int(frac * len(values))
values = random.sample(values, n)
return {"values": values}
@curried.curry
def to_json(
data,
prefix="altair-data",
extension="json",
filename="{prefix}-{hash}.{extension}",
urlpath="",
):
"""
Write the data model to a .json file and return a url based data model.
"""
data_json = _data_to_json_string(data)
data_hash = _compute_data_hash(data_json)
filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
with open(filename, "w") as f:
f.write(data_json)
return {"url": os.path.join(urlpath, filename), "format": {"type": "json"}}
@curried.curry
def to_csv(
data,
prefix="altair-data",
extension="csv",
filename="{prefix}-{hash}.{extension}",
urlpath="",
):
"""Write the data model to a .csv file and return a url based data model."""
data_csv = _data_to_csv_string(data)
data_hash = _compute_data_hash(data_csv)
filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
with open(filename, "w") as f:
f.write(data_csv)
return {"url": os.path.join(urlpath, filename), "format": {"type": "csv"}}
@curried.curry
def to_values(data):
"""Replace a DataFrame by a data model with values."""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
if isinstance(data, pd.DataFrame):
data = sanitize_dataframe(data)
data = sanitize_geo_interface(data.__geo_interface__)
return {"values": data}
elif isinstance(data, pd.DataFrame):
data = sanitize_dataframe(data)
return {"values": data.to_dict(orient="records")}
elif isinstance(data, dict):
if "values" not in data:
raise KeyError("values expected in data dict, but not present.")
return data
def check_data_type(data):
"""Raise if the data is not a dict or DataFrame."""
if not isinstance(data, (dict, pd.DataFrame)) and not hasattr(
data, "__geo_interface__"
):
raise TypeError(
"Expected dict, DataFrame or a __geo_interface__ attribute, got: {}".format(
type(data)
)
)
# ==============================================================================
# Private utilities
# ==============================================================================
def _compute_data_hash(data_str):
return hashlib.md5(data_str.encode()).hexdigest()
def _data_to_json_string(data):
"""Return a JSON string representation of the input data"""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
if isinstance(data, pd.DataFrame):
data = sanitize_dataframe(data)
data = sanitize_geo_interface(data.__geo_interface__)
return json.dumps(data)
elif isinstance(data, pd.DataFrame):
data = sanitize_dataframe(data)
return data.to_json(orient="records", double_precision=15)
elif isinstance(data, dict):
if "values" not in data:
raise KeyError("values expected in data dict, but not present.")
return json.dumps(data["values"], sort_keys=True)
else:
raise NotImplementedError(
"to_json only works with data expressed as " "a DataFrame or as a dict"
)
def _data_to_csv_string(data):
"""return a CSV string representation of the input data"""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
raise NotImplementedError(
"to_csv does not work with data that "
"contains the __geo_interface__ attribute"
)
elif isinstance(data, pd.DataFrame):
data = sanitize_dataframe(data)
return data.to_csv(index=False)
elif isinstance(data, dict):
if "values" not in data:
raise KeyError("values expected in data dict, but not present")
return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
else:
raise NotImplementedError(
"to_csv only works with data expressed as " "a DataFrame or as a dict"
)
def pipe(data, *funcs):
"""
Pipe a value through a sequence of functions
Deprecated: use toolz.curried.pipe() instead.
"""
warnings.warn(
"alt.pipe() is deprecated, and will be removed in a future release. "
"Use toolz.curried.pipe() instead.",
AltairDeprecationWarning,
)
return curried.pipe(data, *funcs)
def curry(*args, **kwargs):
"""Curry a callable function
Deprecated: use toolz.curried.curry() instead.
"""
warnings.warn(
"alt.curry() is deprecated, and will be removed in a future release. "
"Use toolz.curried.curry() instead.",
AltairDeprecationWarning,
)
return curried.curry(*args, **kwargs)
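
A minimal sketch of the transformer pipeline the module comments describe, assuming the functions above live in altair.utils.data; the DataFrame is invented for illustration.

import pandas as pd
from toolz import curried
from altair.utils.data import limit_rows, sample, to_values  # assumed module path

df = pd.DataFrame({"x": range(100), "y": range(100)})
out = curried.pipe(df, limit_rows(max_rows=1000), sample(n=10), to_values)
print(len(out["values"]))  # 10 -- sampled rows in the {"values": [...]} form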


@@ -0,0 +1,70 @@
import warnings
import functools
class AltairDeprecationWarning(UserWarning):
pass
def deprecated(message=None):
"""Decorator to deprecate a function or class.
Parameters
----------
message : string (optional)
The deprecation message
"""
def wrapper(obj):
return _deprecate(obj, message=message)
return wrapper
def _deprecate(obj, name=None, message=None):
"""Return a version of a class or function that raises a deprecation warning.
Parameters
----------
obj : class or function
The object to create a deprecated version of.
name : string (optional)
The name of the deprecated object
message : string (optional)
The deprecation message
Returns
-------
deprecated_obj :
The deprecated version of obj
Examples
--------
>>> class Foo(object): pass
>>> OldFoo = _deprecate(Foo, "OldFoo")
>>> f = OldFoo() # doctest: +SKIP
AltairDeprecationWarning: alt.OldFoo is deprecated. Use alt.Foo instead.
"""
if message is None:
message = "alt.{} is deprecated. Use alt.{} instead." "".format(
name, obj.__name__
)
if isinstance(obj, type):
return type(
name,
(obj,),
{
"__doc__": obj.__doc__,
"__init__": _deprecate(obj.__init__, "__init__", message),
},
)
elif callable(obj):
@functools.wraps(obj)
def new_obj(*args, **kwargs):
warnings.warn(message, AltairDeprecationWarning)
return obj(*args, **kwargs)
return new_obj
else:
raise ValueError("Cannot deprecate object of type {}".format(type(obj)))
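
A small sketch of the deprecated() decorator defined above; the function name and message are hypothetical, and the import path is assumed.

import warnings
from altair.utils.deprecation import AltairDeprecationWarning, deprecated  # assumed path

@deprecated("Use new_helper instead.")
def old_helper():
    return 42

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert old_helper() == 42  # still calls through to the wrapped function
    assert issubclass(caught[0].category, AltairDeprecationWarning)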


@@ -0,0 +1,182 @@
import json
import pkgutil
import textwrap
from typing import Callable, Dict
import uuid
from jsonschema import validate
from .plugin_registry import PluginRegistry
from .mimebundle import spec_to_mimebundle
# ==============================================================================
# Renderer registry
# ==============================================================================
MimeBundleType = Dict[str, object]
RendererType = Callable[..., MimeBundleType]
class RendererRegistry(PluginRegistry[RendererType]):
entrypoint_err_messages = {
"notebook": textwrap.dedent(
"""
To use the 'notebook' renderer, you must install the vega package
and the associated Jupyter extension.
See https://altair-viz.github.io/getting_started/installation.html
for more information.
"""
),
"altair_viewer": textwrap.dedent(
"""
To use the 'altair_viewer' renderer, you must install the altair_viewer
package; see http://github.com/altair-viz/altair_viewer/
for more information.
"""
),
}
def set_embed_options(
self,
defaultStyle=None,
renderer=None,
width=None,
height=None,
padding=None,
scaleFactor=None,
actions=None,
**kwargs,
):
"""Set options for embeddings of Vega & Vega-Lite charts.
Options are fully documented at https://github.com/vega/vega-embed.
Similar to the `enable()` method, this can be used as either
a persistent global switch, or as a temporary local setting using
a context manager (i.e. a `with` statement).
Parameters
----------
defaultStyle : bool or string
Specify a default stylesheet for embed actions.
renderer : string
The renderer to use for the view. One of "canvas" (default) or "svg"
width : integer
The view width in pixels
height : integer
The view height in pixels
padding : integer
The view padding in pixels
scaleFactor : number
The number by which to multiply the width and height (default 1)
of an exported PNG or SVG image.
actions : bool or dict
Determines if action links ("Export as PNG/SVG", "View Source",
"View Vega" (only for Vega-Lite), "Open in Vega Editor") are
included with the embedded view. If the value is true, all action
links will be shown; if the value is false, none will be. This property
can take a key-value mapping object that maps keys (export, source,
compiled, editor) to boolean values for determining if
each action link should be shown.
**kwargs :
Additional options are passed directly to embed options.
"""
options = {
"defaultStyle": defaultStyle,
"renderer": renderer,
"width": width,
"height": height,
"padding": padding,
"scaleFactor": scaleFactor,
"actions": actions,
}
kwargs.update({key: val for key, val in options.items() if val is not None})
return self.enable(None, embed_options=kwargs)
# ==============================================================================
# VegaLite v1/v2 renderer logic
# ==============================================================================
class Displayable(object):
"""A base display class for VegaLite v1/v2.
This class takes a VegaLite v1/v2 spec and does the following:
1. Optionally validates the spec against a schema.
2. Uses the RendererPlugin to grab a renderer and call it when the
IPython/Jupyter display method (_repr_mimebundle_) is called.
The spec passed to this class must be fully schema compliant and already
have the data portion of the spec fully processed and ready to serialize.
In practice, this means the data portion of the spec should have been passed
through appropriate data model transformers.
"""
renderers = None
schema_path = ("altair", "")
def __init__(self, spec, validate=False):
# type: (dict, bool) -> None
self.spec = spec
self.validate = validate
self._validate()
def _validate(self):
# type: () -> None
"""Validate the spec against the schema."""
schema_dict = json.loads(pkgutil.get_data(*self.schema_path).decode("utf-8"))
validate(self.spec, schema_dict)
def _repr_mimebundle_(self, include=None, exclude=None):
"""Return a MIME bundle for display in Jupyter frontends."""
if self.renderers is not None:
return self.renderers.get()(self.spec)
else:
return {}
def default_renderer_base(spec, mime_type, str_repr, **options):
"""A default renderer for Vega or VegaLite that works for modern frontends.
This renderer works with modern frontends (JupyterLab, nteract) that know
how to render the custom VegaLite MIME type listed above.
"""
assert isinstance(spec, dict)
bundle = {}
metadata = {}
bundle[mime_type] = spec
bundle["text/plain"] = str_repr
if options:
metadata[mime_type] = options
return bundle, metadata
def json_renderer_base(spec, str_repr, **options):
"""A renderer that returns a MIME type of application/json.
In JupyterLab/nteract this is rendered as a nice JSON tree.
"""
return default_renderer_base(
spec, mime_type="application/json", str_repr=str_repr, **options
)
class HTMLRenderer(object):
"""Object to render charts as HTML, with a unique output div each time"""
def __init__(self, output_div="altair-viz-{}", **kwargs):
self._output_div = output_div
self.kwargs = kwargs
@property
def output_div(self):
return self._output_div.format(uuid.uuid4().hex)
def __call__(self, spec, **metadata):
kwargs = self.kwargs.copy()
kwargs.update(metadata)
return spec_to_mimebundle(
spec, format="html", output_div=self.output_div, **kwargs
)
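
A sketch of how default_renderer_base packages a spec into a Jupyter-style MIME bundle; the MIME type string and spec are illustrative placeholders, and the import path is assumed.

from altair.utils.display import default_renderer_base  # assumed module path

spec = {"mark": "point", "data": {"values": [{"x": 1}]}}
bundle, metadata = default_renderer_base(
    spec,
    mime_type="application/vnd.vegalite.v4+json",
    str_repr="alt.Chart(...)",
    renderer="canvas",  # any extra option is attached as display metadata
)
print(sorted(bundle))  # ['application/vnd.vegalite.v4+json', 'text/plain']
print(metadata)        # {'application/vnd.vegalite.v4+json': {'renderer': 'canvas'}}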


@@ -0,0 +1,61 @@
import ast
import sys
if sys.version_info > (3, 8):
Module = ast.Module
else:
# Mock the Python >= 3.8 API
def Module(nodelist, type_ignores):
return ast.Module(nodelist)
class _CatchDisplay(object):
"""Class to temporarily catch sys.displayhook"""
def __init__(self):
self.output = None
def __enter__(self):
self.old_hook = sys.displayhook
sys.displayhook = self
return self
def __exit__(self, type, value, traceback):
sys.displayhook = self.old_hook
# Returning False will cause exceptions to propagate
return False
def __call__(self, output):
self.output = output
def eval_block(code, namespace=None, filename="<string>"):
"""
Execute a multi-line block of code in the given namespace
If the final statement in the code is an expression, return
the result of the expression.
"""
tree = ast.parse(code, filename="<ast>", mode="exec")
if namespace is None:
namespace = {}
catch_display = _CatchDisplay()
if isinstance(tree.body[-1], ast.Expr):
to_exec, to_eval = tree.body[:-1], tree.body[-1:]
else:
to_exec, to_eval = tree.body, []
for node in to_exec:
compiled = compile(Module([node], []), filename=filename, mode="exec")
exec(compiled, namespace)
with catch_display:
for node in to_eval:
compiled = compile(
ast.Interactive([node]), filename=filename, mode="single"
)
exec(compiled, namespace)
return catch_display.output
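
A short sketch of eval_block, which runs a code block and returns the value of a trailing expression (notebook-style); the import path is assumed.

from altair.utils.execeval import eval_block  # assumed module path

ns = {}
result = eval_block("a = 2\nb = 3\na * b", namespace=ns)
print(result)            # 6 -- value of the final expression
print(ns["a"], ns["b"])  # 2 3 -- assignments land in the namespace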


@@ -0,0 +1,236 @@
import json
import jinja2
HTML_TEMPLATE = jinja2.Template(
"""
{%- if fullhtml -%}
<!DOCTYPE html>
<html>
<head>
{%- endif %}
<style>
.error {
color: red;
}
</style>
{%- if not requirejs %}
<script type="text/javascript" src="{{ base_url }}/vega@{{ vega_version }}"></script>
{%- if mode == 'vega-lite' %}
<script type="text/javascript" src="{{ base_url }}/vega-lite@{{ vegalite_version }}"></script>
{%- endif %}
<script type="text/javascript" src="{{ base_url }}/vega-embed@{{ vegaembed_version }}"></script>
{%- endif %}
{%- if fullhtml %}
{%- if requirejs %}
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>
<script>
requirejs.config({
"paths": {
"vega": "{{ base_url }}/vega@{{ vega_version }}?noext",
"vega-lib": "{{ base_url }}/vega-lib?noext",
"vega-lite": "{{ base_url }}/vega-lite@{{ vegalite_version }}?noext",
"vega-embed": "{{ base_url }}/vega-embed@{{ vegaembed_version }}?noext",
}
});
</script>
{%- endif %}
</head>
<body>
{%- endif %}
<div id="{{ output_div }}"></div>
<script>
{%- if requirejs and not fullhtml %}
requirejs.config({
"paths": {
"vega": "{{ base_url }}/vega@{{ vega_version }}?noext",
"vega-lib": "{{ base_url }}/vega-lib?noext",
"vega-lite": "{{ base_url }}/vega-lite@{{ vegalite_version }}?noext",
"vega-embed": "{{ base_url }}/vega-embed@{{ vegaembed_version }}?noext",
}
});
{% endif %}
{% if requirejs -%}
require(['vega-embed'],
{%- else -%}
(
{%- endif -%}
function(vegaEmbed) {
var spec = {{ spec }};
var embedOpt = {{ embed_options }};
function showError(el, error){
el.innerHTML = ('<div class="error" style="color:red;">'
+ '<p>JavaScript Error: ' + error.message + '</p>'
+ "<p>This usually means there's a typo in your chart specification. "
+ "See the javascript console for the full traceback.</p>"
+ '</div>');
throw error;
}
const el = document.getElementById('{{ output_div }}');
vegaEmbed("#{{ output_div }}", spec, embedOpt)
.catch(error => showError(el, error));
}){% if not requirejs %}(vegaEmbed){% endif %};
</script>
{%- if fullhtml %}
</body>
</html>
{%- endif %}
"""
)
HTML_TEMPLATE_UNIVERSAL = jinja2.Template(
"""
<div id="{{ output_div }}"></div>
<script type="text/javascript">
var VEGA_DEBUG = (typeof VEGA_DEBUG == "undefined") ? {} : VEGA_DEBUG;
(function(spec, embedOpt){
let outputDiv = document.currentScript.previousElementSibling;
if (outputDiv.id !== "{{ output_div }}") {
outputDiv = document.getElementById("{{ output_div }}");
}
const paths = {
"vega": "{{ base_url }}/vega@{{ vega_version }}?noext",
"vega-lib": "{{ base_url }}/vega-lib?noext",
"vega-lite": "{{ base_url }}/vega-lite@{{ vegalite_version }}?noext",
"vega-embed": "{{ base_url }}/vega-embed@{{ vegaembed_version }}?noext",
};
function maybeLoadScript(lib, version) {
var key = `${lib.replace("-", "")}_version`;
return (VEGA_DEBUG[key] == version) ?
Promise.resolve(paths[lib]) :
new Promise(function(resolve, reject) {
var s = document.createElement('script');
document.getElementsByTagName("head")[0].appendChild(s);
s.async = true;
s.onload = () => {
VEGA_DEBUG[key] = version;
return resolve(paths[lib]);
};
s.onerror = () => reject(`Error loading script: ${paths[lib]}`);
s.src = paths[lib];
});
}
function showError(err) {
outputDiv.innerHTML = `<div class="error" style="color:red;">${err}</div>`;
throw err;
}
function displayChart(vegaEmbed) {
vegaEmbed(outputDiv, spec, embedOpt)
.catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));
}
if(typeof define === "function" && define.amd) {
requirejs.config({paths});
require(["vega-embed"], displayChart, err => showError(`Error loading script: ${err.message}`));
} else {
maybeLoadScript("vega", "{{vega_version}}")
.then(() => maybeLoadScript("vega-lite", "{{vegalite_version}}"))
.then(() => maybeLoadScript("vega-embed", "{{vegaembed_version}}"))
.catch(showError)
.then(() => displayChart(vegaEmbed));
}
})({{ spec }}, {{ embed_options }});
</script>
"""
)
TEMPLATES = {
"standard": HTML_TEMPLATE,
"universal": HTML_TEMPLATE_UNIVERSAL,
}
def spec_to_html(
spec,
mode,
vega_version,
vegaembed_version,
vegalite_version=None,
base_url="https://cdn.jsdelivr.net/npm/",
output_div="vis",
embed_options=None,
json_kwds=None,
fullhtml=True,
requirejs=False,
template="standard",
):
"""Embed a Vega/Vega-Lite spec into an HTML page
Parameters
----------
spec : dict
a dictionary representing a vega-lite plot spec.
mode : string {'vega' | 'vega-lite'}
The rendering mode. This value is overridden by embed_options['mode'],
if it is present.
vega_version : string
For html output, the version of vega.js to use.
vegalite_version : string
For html output, the version of vegalite.js to use.
vegaembed_version : string
For html output, the version of vegaembed.js to use.
base_url : string (optional)
The base url from which to load the javascript libraries.
output_div : string (optional)
The id of the div element where the plot will be shown.
embed_options : dict (optional)
Dictionary of options to pass to the vega-embed script. Default
entry is {'mode': mode}.
json_kwds : dict (optional)
Dictionary of keywords to pass to json.dumps().
fullhtml : boolean (optional)
If True (default) then return a full html page. If False, then return
an HTML snippet that can be embedded into an HTML page.
requirejs : boolean (optional)
If False (default) then load libraries from base_url using <script>
tags. If True, then load libraries using requirejs
template : jinja2.Template or string (optional)
Specify the template to use (default = 'standard'). If template is a
string, it must be one of {'universal', 'standard'}. Otherwise, it
can be a jinja2.Template object containing a custom template.
Returns
-------
output : string
an HTML string for rendering the chart.
"""
embed_options = embed_options or {}
json_kwds = json_kwds or {}
mode = embed_options.setdefault("mode", mode)
if mode not in ["vega", "vega-lite"]:
raise ValueError("mode must be either 'vega' or 'vega-lite'")
if vega_version is None:
raise ValueError("must specify vega_version")
if vegaembed_version is None:
raise ValueError("must specify vegaembed_version")
if mode == "vega-lite" and vegalite_version is None:
raise ValueError("must specify vega-lite version for mode='vega-lite'")
template = TEMPLATES.get(template, template)
if not hasattr(template, "render"):
raise ValueError("Invalid template: {0}".format(template))
return template.render(
spec=json.dumps(spec, **json_kwds),
embed_options=json.dumps(embed_options),
mode=mode,
vega_version=vega_version,
vegalite_version=vegalite_version,
vegaembed_version=vegaembed_version,
base_url=base_url,
output_div=output_div,
fullhtml=fullhtml,
requirejs=requirejs,
)
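
A minimal sketch of spec_to_html; the spec, div id, and version strings are placeholders (they are only interpolated into CDN URLs), so substitute the versions your environment actually uses.

from altair.utils.html import spec_to_html  # assumed module path

spec = {"mark": "point", "data": {"values": [{"x": 1, "y": 2}]}}
page = spec_to_html(
    spec,
    mode="vega-lite",
    vega_version="5",        # placeholder versions
    vegaembed_version="6",
    vegalite_version="4",
    output_div="my-chart",
)
assert "my-chart" in page and "<!DOCTYPE html>" in page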


@@ -0,0 +1,83 @@
from .html import spec_to_html
def spec_to_mimebundle(
spec,
format,
mode=None,
vega_version=None,
vegaembed_version=None,
vegalite_version=None,
**kwargs,
):
"""Convert a vega/vega-lite specification to a mimebundle
The mimebundle type is controlled by the ``format`` argument, which can be
one of the following ['html', 'json', 'png', 'svg', 'pdf', 'vega', 'vega-lite']
Parameters
----------
spec : dict
a dictionary representing a vega-lite plot spec
format : string {'html', 'json', 'png', 'svg', 'pdf', 'vega', 'vega-lite'}
the file format to be saved.
mode : string {'vega', 'vega-lite'}
The rendering mode.
vega_version : string
The version of vega.js to use
vegaembed_version : string
The version of vegaembed.js to use
vegalite_version : string
The version of vegalite.js to use. Only required if mode=='vega-lite'
**kwargs :
Additional arguments will be passed to the generating function
Returns
-------
output : dict
a mime-bundle representing the image
Note
----
The png, svg, pdf, and vega outputs require the altair_saver package
to be installed.
"""
if mode not in ["vega", "vega-lite"]:
raise ValueError("mode must be either 'vega' or 'vega-lite'")
if mode == "vega" and format == "vega":
if vega_version is None:
raise ValueError("Must specify vega_version")
return {"application/vnd.vega.v{}+json".format(vega_version[0]): spec}
if format in ["png", "svg", "pdf", "vega"]:
try:
import altair_saver
except ImportError:
raise ValueError(
"Saving charts in {fmt!r} format requires the altair_saver package: "
"see http://github.com/altair-viz/altair_saver/".format(fmt=format)
)
return altair_saver.render(spec, format, mode=mode, **kwargs)
if format == "html":
html = spec_to_html(
spec,
mode=mode,
vega_version=vega_version,
vegaembed_version=vegaembed_version,
vegalite_version=vegalite_version,
**kwargs,
)
return {"text/html": html}
if format == "vega-lite":
assert mode == "vega-lite" # sanity check: should never be False
if mode == "vega":
raise ValueError("Cannot convert a vega spec to vegalite")
if vegalite_version is None:
raise ValueError("Must specify vegalite_version")
return {"application/vnd.vegalite.v{}+json".format(vegalite_version[0]): spec}
if format == "json":
return {"application/json": spec}
raise ValueError(
"format must be one of "
"['html', 'json', 'png', 'svg', 'pdf', 'vega', 'vega-lite']"
)
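
A sketch of the two spec_to_mimebundle paths that need no optional dependencies ('json' and 'vega-lite'); the spec and version string are illustrative, and the import path is assumed.

from altair.utils.mimebundle import spec_to_mimebundle  # assumed module path

spec = {"mark": "bar", "data": {"values": [{"x": 1}]}}
print(spec_to_mimebundle(spec, format="json", mode="vega-lite"))
# {'application/json': {...}}
print(spec_to_mimebundle(spec, format="vega-lite", mode="vega-lite", vegalite_version="4.17.0"))
# {'application/vnd.vegalite.v4+json': {...}}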


@@ -0,0 +1,199 @@
from typing import Any, Dict, List, Optional, Generic, TypeVar, cast
from types import TracebackType
import entrypoints
from toolz import curry
PluginType = TypeVar("PluginType")
class PluginEnabler(object):
"""Context manager for enabling plugins
This object lets you use enable() as a context manager to
temporarily enable a given plugin::
with plugins.enable('name'):
do_something() # 'name' plugin temporarily enabled
# plugins back to original state
"""
def __init__(self, registry: "PluginRegistry", name: str, **options):
self.registry = registry # type: PluginRegistry
self.name = name # type: str
self.options = options # type: Dict[str, Any]
self.original_state = registry._get_state() # type: Dict[str, Any]
self.registry._enable(name, **options)
def __enter__(self) -> "PluginEnabler":
return self
def __exit__(self, typ: type, value: Exception, traceback: TracebackType) -> None:
self.registry._set_state(self.original_state)
def __repr__(self) -> str:
return "{}.enable({!r})".format(self.registry.__class__.__name__, self.name)
class PluginRegistry(Generic[PluginType]):
"""A registry for plugins.
This is a plugin registry that allows plugins to be loaded/registered
in two ways:
1. Through an explicit call to ``.register(name, value)``.
2. By looking for other Python packages that are installed and provide
a setuptools entry point group.
When you create an instance of this class, provide the name of the
entry point group to use::
reg = PluginRegistry('my_entrypoint_group')
"""
# this is a mapping of name to error message to allow custom error messages
# in case an entrypoint is not found
entrypoint_err_messages = {} # type: Dict[str, str]
# global settings is a key-value mapping of settings that are stored globally
# in the registry rather than passed to the plugins
_global_settings = {} # type: Dict[str, Any]
def __init__(self, entry_point_group: str = "", plugin_type: type = object):
"""Create a PluginRegistry for a named entry point group.
Parameters
==========
entry_point_group: str
The name of the entry point group.
plugin_type: object
A type that will optionally be used for runtime type checking of
loaded plugins using isinstance.
"""
self.entry_point_group = entry_point_group # type: str
self.plugin_type = plugin_type # type: Optional[type]
self._active = None # type: Optional[PluginType]
self._active_name = "" # type: str
self._plugins = {} # type: Dict[str, PluginType]
self._options = {} # type: Dict[str, Any]
self._global_settings = self.__class__._global_settings.copy() # type: dict
def register(self, name: str, value: Optional[PluginType]) -> Optional[PluginType]:
"""Register a plugin by name and value.
This method is used for explicit registration of a plugin and shouldn't be
used to manage entry point managed plugins, which are auto-loaded.
Parameters
==========
name: str
The name of the plugin.
value: PluginType or None
The actual plugin object to register or None to unregister that plugin.
Returns
=======
plugin: PluginType or None
The plugin that was registered or unregistered.
"""
if value is None:
return self._plugins.pop(name, None)
else:
assert isinstance(value, self.plugin_type)
self._plugins[name] = value
return value
def names(self) -> List[str]:
"""List the names of the registered and entry points plugins."""
exts = list(self._plugins.keys())
more_exts = [
ep.name for ep in entrypoints.get_group_all(self.entry_point_group)
]
exts.extend(more_exts)
return sorted(set(exts))
def _get_state(self) -> Dict[str, Any]:
"""Return a dictionary representing the current state of the registry"""
return {
"_active": self._active,
"_active_name": self._active_name,
"_plugins": self._plugins.copy(),
"_options": self._options.copy(),
"_global_settings": self._global_settings.copy(),
}
def _set_state(self, state: Dict[str, Any]) -> None:
"""Reset the state of the registry"""
assert set(state.keys()) == {
"_active",
"_active_name",
"_plugins",
"_options",
"_global_settings",
}
for key, val in state.items():
setattr(self, key, val)
def _enable(self, name: str, **options) -> None:
if name not in self._plugins:
try:
ep = entrypoints.get_single(self.entry_point_group, name)
except entrypoints.NoSuchEntryPoint:
if name in self.entrypoint_err_messages:
raise ValueError(self.entrypoint_err_messages[name])
else:
raise
value = cast(PluginType, ep.load())
self.register(name, value)
self._active_name = name
self._active = self._plugins[name]
for key in set(options.keys()) & set(self._global_settings.keys()):
self._global_settings[key] = options.pop(key)
self._options = options
def enable(self, name: Optional[str] = None, **options) -> PluginEnabler:
"""Enable a plugin by name.
This can be either called directly, or used as a context manager.
Parameters
----------
name : string (optional)
The name of the plugin to enable. If not specified, then use the
current active name.
**options :
Any additional parameters will be passed to the plugin as keyword
arguments
Returns
-------
PluginEnabler:
An object that allows enable() to be used as a context manager
"""
if name is None:
name = self.active
return PluginEnabler(self, name, **options)
@property
def active(self) -> str:
"""Return the name of the currently active plugin"""
return self._active_name
@property
def options(self) -> Dict[str, Any]:
"""Return the current options dictionary"""
return self._options
def get(self) -> Optional[PluginType]:
"""Return the currently active plugin."""
if self._options:
return curry(self._active, **self._options)
else:
return self._active
def __repr__(self) -> str:
return "{}(active={!r}, registered={!r})" "".format(
self.__class__.__name__, self._active_name, list(self.names())
)
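
A sketch of explicit registration and the enable() context manager described above; the registry, entry point group, and plugin are all invented for illustration, and the import path is assumed.

from altair.utils.plugin_registry import PluginRegistry  # assumed module path

renderers = PluginRegistry(entry_point_group="demo.renderers")  # hypothetical group

def plain_renderer(spec):
    return {"text/plain": str(spec)}

renderers.register("plain", plain_renderer)
renderers.enable("plain")
print(renderers.active)                    # 'plain'
print(renderers.get()({"mark": "point"}))  # {'text/plain': "{'mark': 'point'}"}

with renderers.enable("plain", indent=2):  # options apply only inside the block
    print(renderers.options)               # {'indent': 2}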


@@ -0,0 +1,134 @@
import json
import pathlib
from .mimebundle import spec_to_mimebundle
def write_file_or_filename(fp, content, mode="w"):
"""Write content to fp, whether fp is a string, a pathlib Path or a
file-like object"""
if isinstance(fp, str) or isinstance(fp, pathlib.PurePath):
with open(fp, mode) as f:
f.write(content)
else:
fp.write(content)
def save(
chart,
fp,
vega_version,
vegaembed_version,
format=None,
mode=None,
vegalite_version=None,
embed_options=None,
json_kwds=None,
webdriver="chrome",
scale_factor=1,
**kwargs,
):
"""Save a chart to file in a variety of formats
Supported formats are [json, html, png, svg]
Parameters
----------
chart : alt.Chart
the chart instance to save
fp : string filename, pathlib.Path or file-like object
file to which to write the chart.
format : string (optional)
the format to write: one of ['json', 'html', 'png', 'svg'].
If not specified, the format will be determined from the filename.
mode : string (optional)
Either 'vega' or 'vega-lite'. If not specified, then infer the mode from
the '$schema' property of the spec, or the ``embed_options`` dictionary.
If it's not specified in either of those places, then use 'vega-lite'.
vega_version : string
For html output, the version of vega.js to use
vegalite_version : string
For html output, the version of vegalite.js to use
vegaembed_version : string
For html output, the version of vegaembed.js to use
embed_options : dict
The vegaEmbed options dictionary. Default is {}
(See https://github.com/vega/vega-embed for details)
json_kwds : dict
Additional keyword arguments are passed to the output method
associated with the specified format.
webdriver : string {'chrome' | 'firefox'}
Webdriver to use for png or svg output
scale_factor : float
scale_factor to use to change size/resolution of png or svg output
**kwargs :
additional kwargs passed to spec_to_mimebundle.
"""
if json_kwds is None:
json_kwds = {}
if embed_options is None:
embed_options = {}
if format is None:
if isinstance(fp, str):
format = fp.split(".")[-1]
elif isinstance(fp, pathlib.PurePath):
format = fp.suffix.lstrip(".")
else:
raise ValueError(
"must specify file format: " "['png', 'svg', 'pdf', 'html', 'json']"
)
spec = chart.to_dict()
if mode is None:
if "mode" in embed_options:
mode = embed_options["mode"]
elif "$schema" in spec:
mode = spec["$schema"].split("/")[-2]
else:
mode = "vega-lite"
if mode not in ["vega", "vega-lite"]:
raise ValueError("mode must be 'vega' or 'vega-lite', " "not '{}'".format(mode))
if mode == "vega-lite" and vegalite_version is None:
raise ValueError("must specify vega-lite version")
if format == "json":
json_spec = json.dumps(spec, **json_kwds)
write_file_or_filename(fp, json_spec, mode="w")
elif format == "html":
mimebundle = spec_to_mimebundle(
spec=spec,
format=format,
mode=mode,
vega_version=vega_version,
vegalite_version=vegalite_version,
vegaembed_version=vegaembed_version,
embed_options=embed_options,
json_kwds=json_kwds,
**kwargs,
)
write_file_or_filename(fp, mimebundle["text/html"], mode="w")
elif format in ["png", "svg", "pdf"]:
mimebundle = spec_to_mimebundle(
spec=spec,
format=format,
mode=mode,
vega_version=vega_version,
vegalite_version=vegalite_version,
vegaembed_version=vegaembed_version,
webdriver=webdriver,
scale_factor=scale_factor,
**kwargs,
)
if format == "png":
write_file_or_filename(fp, mimebundle["image/png"], mode="wb")
elif format == "pdf":
write_file_or_filename(fp, mimebundle["application/pdf"], mode="wb")
else:
write_file_or_filename(fp, mimebundle["image/svg+xml"], mode="w")
else:
raise ValueError("unrecognized format: '{}'".format(format))
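
A sketch of save(); it only needs an object with a to_dict() method (normally an alt.Chart), so a tiny stand-in is used here, the version strings are placeholders, and the import path is assumed.

from altair.utils.save import save  # assumed module path

class FakeChart:
    """Stand-in for alt.Chart: save() only calls to_dict()."""
    def to_dict(self):
        return {"mark": "point", "data": {"values": [{"x": 1, "y": 2}]}}

# Writes chart.json / chart.html into the working directory.
save(FakeChart(), "chart.json", mode="vega-lite",
     vega_version="5", vegaembed_version="6", vegalite_version="4.17.0")
save(FakeChart(), "chart.html", mode="vega-lite",
     vega_version="5", vegaembed_version="6", vegalite_version="4.17.0")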


@@ -0,0 +1,587 @@
# The contents of this file are automatically written by
# tools/generate_schema_wrapper.py. Do not modify directly.
import collections
import contextlib
import inspect
import json
import jsonschema
import numpy as np
import pandas as pd
# If DEBUG_MODE is True, then schema objects are converted to dict and
# validated at creation time. This slows things down, particularly for
# larger specs, but leads to much more useful tracebacks for the user.
# Individual schema classes can override this by setting the
# class-level _class_is_valid_at_instantiation attribute to False
DEBUG_MODE = True
def enable_debug_mode():
global DEBUG_MODE
DEBUG_MODE = True
def disable_debug_mode():
global DEBUG_MODE
DEBUG_MODE = False
@contextlib.contextmanager
def debug_mode(arg):
global DEBUG_MODE
original = DEBUG_MODE
DEBUG_MODE = arg
try:
yield
finally:
DEBUG_MODE = original
def _subclasses(cls):
"""Breadth-first sequence of all classes which inherit from cls."""
seen = set()
current_set = {cls}
while current_set:
seen |= current_set
current_set = set.union(*(set(cls.__subclasses__()) for cls in current_set))
for cls in current_set - seen:
yield cls
def _todict(obj, validate, context):
"""Convert an object to a dict representation."""
if isinstance(obj, SchemaBase):
return obj.to_dict(validate=validate, context=context)
elif isinstance(obj, (list, tuple, np.ndarray)):
return [_todict(v, validate, context) for v in obj]
elif isinstance(obj, dict):
return {
k: _todict(v, validate, context)
for k, v in obj.items()
if v is not Undefined
}
elif hasattr(obj, "to_dict"):
return obj.to_dict()
elif isinstance(obj, np.number):
return float(obj)
elif isinstance(obj, (pd.Timestamp, np.datetime64)):
return pd.Timestamp(obj).isoformat()
else:
return obj
def _resolve_references(schema, root=None):
"""Resolve schema references."""
resolver = jsonschema.RefResolver.from_schema(root or schema)
while "$ref" in schema:
with resolver.resolving(schema["$ref"]) as resolved:
schema = resolved
return schema
class SchemaValidationError(jsonschema.ValidationError):
"""A wrapper for jsonschema.ValidationError with friendlier traceback"""
def __init__(self, obj, err):
super(SchemaValidationError, self).__init__(**self._get_contents(err))
self.obj = obj
@staticmethod
def _get_contents(err):
"""Get a dictionary with the contents of a ValidationError"""
try:
# works in jsonschema 2.3 or later
contents = err._contents()
except AttributeError:
try:
# works in Python >=3.4
spec = inspect.getfullargspec(err.__init__)
except AttributeError:
# works in Python <3.4
spec = inspect.getargspec(err.__init__)
contents = {key: getattr(err, key) for key in spec.args[1:]}
return contents
def __str__(self):
cls = self.obj.__class__
schema_path = ["{}.{}".format(cls.__module__, cls.__name__)]
schema_path.extend(self.schema_path)
schema_path = "->".join(
str(val)
for val in schema_path[:-1]
if val not in ("properties", "additionalProperties", "patternProperties")
)
return """Invalid specification
{}, validating {!r}
{}
""".format(
schema_path, self.validator, self.message
)
class UndefinedType(object):
"""A singleton object for marking undefined attributes"""
__instance = None
def __new__(cls, *args, **kwargs):
if not isinstance(cls.__instance, cls):
cls.__instance = object.__new__(cls, *args, **kwargs)
return cls.__instance
def __repr__(self):
return "Undefined"
Undefined = UndefinedType()
class SchemaBase(object):
"""Base class for schema wrappers.
Each derived class should set the _schema class attribute (and optionally
the _rootschema class attribute) which is used for validation.
"""
_schema = None
_rootschema = None
_class_is_valid_at_instantiation = True
_validator = jsonschema.Draft7Validator
def __init__(self, *args, **kwds):
# Two valid options for initialization, which should be handled by
# derived classes:
# - a single arg with no kwds, for, e.g. {'type': 'string'}
# - zero args with zero or more kwds for {'type': 'object'}
if self._schema is None:
raise ValueError(
"Cannot instantiate object of type {}: "
"_schema class attribute is not defined."
"".format(self.__class__)
)
if kwds:
assert len(args) == 0
else:
assert len(args) in [0, 1]
# use object.__setattr__ because we override setattr below.
object.__setattr__(self, "_args", args)
object.__setattr__(self, "_kwds", kwds)
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)
def copy(self, deep=True, ignore=()):
"""Return a copy of the object
Parameters
----------
deep : boolean or list, optional
If True (default) then return a deep copy of all dict, list, and
SchemaBase objects within the object structure.
If False, then only copy the top object.
If a list or iterable, then only copy the listed attributes.
ignore : list, optional
A list of keys for which the contents should not be copied, but
only stored by reference.
"""
def _shallow_copy(obj):
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, list):
return obj[:]
elif isinstance(obj, dict):
return obj.copy()
else:
return obj
def _deep_copy(obj, ignore=()):
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj._kwds.items()
}
with debug_mode(False):
return obj.__class__(*args, **kwds)
elif isinstance(obj, list):
return [_deep_copy(v, ignore=ignore) for v in obj]
elif isinstance(obj, dict):
return {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj.items()
}
else:
return obj
try:
deep = list(deep)
except TypeError:
deep_is_list = False
else:
deep_is_list = True
if deep and not deep_is_list:
return _deep_copy(self, ignore=ignore)
with debug_mode(False):
copy = self.__class__(*self._args, **self._kwds)
if deep_is_list:
for attr in deep:
copy[attr] = _shallow_copy(copy._get(attr))
return copy
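# Usage sketch for copy(), assuming a hypothetical ``spec`` instance of a
# SchemaBase subclass (the attribute name below is illustrative only).
#
#   >>> full = spec.copy()                      # deep copy of all nested objects
#   >>> shallow = spec.copy(deep=False)         # top-level copy; nested objects shared
#   >>> partial = spec.copy(deep=["encoding"])  # deep-copy only the listed attribute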
def _get(self, attr, default=Undefined):
"""Get an attribute, returning default if not present."""
attr = self._kwds.get(attr, Undefined)
if attr is Undefined:
attr = default
return attr
def __getattr__(self, attr):
# reminder: getattr is called after the normal lookups
if attr == "_kwds":
raise AttributeError()
if attr in self._kwds:
return self._kwds[attr]
else:
try:
_getattr = super(SchemaBase, self).__getattr__
except AttributeError:
_getattr = super(SchemaBase, self).__getattribute__
return _getattr(attr)
def __setattr__(self, item, val):
self._kwds[item] = val
def __getitem__(self, item):
return self._kwds[item]
def __setitem__(self, item, val):
self._kwds[item] = val
def __repr__(self):
if self._kwds:
args = (
"{}: {!r}".format(key, val)
for key, val in sorted(self._kwds.items())
if val is not Undefined
)
args = "\n" + ",\n".join(args)
return "{0}({{{1}\n}})".format(
self.__class__.__name__, args.replace("\n", "\n ")
)
else:
return "{}({!r})".format(self.__class__.__name__, self._args[0])
def __eq__(self, other):
return (
type(self) is type(other)
and self._args == other._args
and self._kwds == other._kwds
)
def to_dict(self, validate=True, ignore=None, context=None):
"""Return a dictionary representation of the object
Parameters
----------
validate : boolean or string
If True (default), then validate the output dictionary
against the schema. If "deep" then recursively validate
all objects in the spec. This takes much more time, but
it results in friendlier tracebacks for large objects.
ignore : list
A list of keys to ignore. This will *not* be passed to child to_dict
function calls.
context : dict (optional)
A context dictionary that will be passed to all child to_dict
function calls
Returns
-------
dct : dictionary
The dictionary representation of this object
Raises
------
jsonschema.ValidationError :
if validate=True and the dict does not conform to the schema
"""
if context is None:
context = {}
if ignore is None:
ignore = []
sub_validate = "deep" if validate == "deep" else False
if self._args and not self._kwds:
result = _todict(self._args[0], validate=sub_validate, context=context)
elif not self._args:
result = _todict(
{k: v for k, v in self._kwds.items() if k not in ignore},
validate=sub_validate,
context=context,
)
else:
raise ValueError(
"{} instance has both a value and properties : "
"cannot serialize to dict".format(self.__class__)
)
if validate:
try:
self.validate(result)
except jsonschema.ValidationError as err:
raise SchemaValidationError(self, err)
return result
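# Usage sketch, assuming a hypothetical SchemaBase subclass ``Point`` whose
# schema allows the properties shown: keyword properties become dictionary
# keys, and Undefined values are omitted from the output.
#
#   >>> Point(x=1, y=2).to_dict()
#   {'x': 1, 'y': 2}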
def to_json(
self, validate=True, ignore=[], context={}, indent=2, sort_keys=True, **kwargs
):
"""Emit the JSON representation for this object as a string.
Parameters
----------
validate : boolean or string
If True (default), then validate the output dictionary
against the schema. If "deep" then recursively validate
all objects in the spec. This takes much more time, but
it results in friendlier tracebacks for large objects.
ignore : list
A list of keys to ignore. This will *not* be passed to child to_dict
function calls.
context : dict (optional)
A context dictionary that will be passed to all child to_dict
function calls
indent : integer, default 2
the number of spaces of indentation to use
sort_keys : boolean, default True
if True, sort keys in the output
**kwargs
Additional keyword arguments are passed to ``json.dumps()``
Returns
-------
spec : string
The JSON specification of the chart object.
"""
dct = self.to_dict(validate=validate, ignore=ignore, context=context)
return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs)
@classmethod
def _default_wrapper_classes(cls):
"""Return the set of classes used within cls.from_dict()"""
return _subclasses(SchemaBase)
@classmethod
def from_dict(cls, dct, validate=True, _wrapper_classes=None):
"""Construct class from a dictionary representation
Parameters
----------
dct : dictionary
The dict from which to construct the class
validate : boolean
If True (default), then validate the input against the schema.
_wrapper_classes : list (optional)
The set of SchemaBase classes to use when constructing wrappers
of the dict inputs. If not specified, the result of
cls._default_wrapper_classes will be used.
Returns
-------
obj : Schema object
The wrapped schema
Raises
------
jsonschema.ValidationError :
if validate=True and dct does not conform to the schema
"""
if validate:
cls.validate(dct)
if _wrapper_classes is None:
_wrapper_classes = cls._default_wrapper_classes()
converter = _FromDict(_wrapper_classes)
return converter.from_dict(dct, cls)
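# Usage sketch, assuming the same hypothetical ``Point`` subclass as above:
# from_dict() inverts to_dict(), rebuilding wrapper objects from plain dicts.
#
#   >>> Point.from_dict({'x': 1, 'y': 2}).to_dict() == {'x': 1, 'y': 2}
#   True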
@classmethod
def from_json(cls, json_string, validate=True, **kwargs):
"""Instantiate the object from a valid JSON string
Parameters
----------
json_string : string
The string containing a valid JSON chart specification.
validate : boolean
If True (default), then validate the input against the schema.
**kwargs :
Additional keyword arguments are passed to json.loads
Returns
-------
chart : Chart object
The altair Chart object built from the specification.
"""
dct = json.loads(json_string, **kwargs)
return cls.from_dict(dct, validate=validate)
@classmethod
def validate(cls, instance, schema=None):
"""
Validate the instance against the class schema in the context of the
rootschema.
"""
if schema is None:
schema = cls._schema
resolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
return jsonschema.validate(
instance, schema, cls=cls._validator, resolver=resolver
)
@classmethod
def resolve_references(cls, schema=None):
"""Resolve references in the context of this object's schema or root schema."""
return _resolve_references(
schema=(schema or cls._schema),
root=(cls._rootschema or cls._schema or schema),
)
@classmethod
def validate_property(cls, name, value, schema=None):
"""
Validate a property against property schema in the context of the
rootschema
"""
value = _todict(value, validate=False, context={})
props = cls.resolve_references(schema or cls._schema).get("properties", {})
resolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
return jsonschema.validate(value, props.get(name, {}), resolver=resolver)
def __dir__(self):
return list(self._kwds.keys())
def _passthrough(*args, **kwds):
return args[0] if args else kwds
class _FromDict(object):
"""Class used to construct SchemaBase class hierarchies from a dict
The primary purpose of using this class is to be able to build a hash table
that maps schemas to their wrapper classes. The candidate classes are
specified in the ``class_list`` argument to the constructor.
"""
_hash_exclude_keys = ("definitions", "title", "description", "$schema", "id")
def __init__(self, class_list):
# Create a mapping of a schema hash to a list of matching classes
# This lets us quickly determine the correct class to construct
self.class_dict = collections.defaultdict(list)
for cls in class_list:
if cls._schema is not None:
self.class_dict[self.hash_schema(cls._schema)].append(cls)
@classmethod
def hash_schema(cls, schema, use_json=True):
"""
Compute a python hash for a nested dictionary which
properly handles dicts, lists, sets, and tuples.
At the top level, the function excludes from the hashed schema all keys
listed in `_hash_exclude_keys`.
This implements two methods: one based on conversion to JSON, and one based
on recursive conversions of unhashable to hashable types; the former seems
to be slightly faster in several benchmarks.
"""
if cls._hash_exclude_keys and isinstance(schema, dict):
schema = {
key: val
for key, val in schema.items()
if key not in cls._hash_exclude_keys
}
if use_json:
s = json.dumps(schema, sort_keys=True)
return hash(s)
else:
def _freeze(val):
if isinstance(val, dict):
return frozenset((k, _freeze(v)) for k, v in val.items())
elif isinstance(val, set):
return frozenset(map(_freeze, val))
elif isinstance(val, list) or isinstance(val, tuple):
return tuple(map(_freeze, val))
else:
return val
return hash(_freeze(schema))
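# Illustrative note: keys listed in _hash_exclude_keys (e.g. "description")
# are dropped at the top level, so schemas that differ only in such metadata
# hash identically and map to the same wrapper classes.
#
#   >>> h1 = _FromDict.hash_schema({"type": "string", "description": "x"})
#   >>> h2 = _FromDict.hash_schema({"type": "string"})
#   >>> h1 == h2
#   True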
def from_dict(
self, dct, cls=None, schema=None, rootschema=None, default_class=_passthrough
):
"""Construct an object from a dict representation"""
if (schema is None) == (cls is None):
raise ValueError("Must provide either cls or schema, but not both.")
if schema is None:
schema = schema or cls._schema
rootschema = rootschema or cls._rootschema
rootschema = rootschema or schema
if isinstance(dct, SchemaBase):
return dct
if cls is None:
# If there are multiple matches, we use the first one in the dict.
# Our class dict is constructed breadth-first from top to bottom,
# so the first class that matches is the most general match.
matches = self.class_dict[self.hash_schema(schema)]
if matches:
cls = matches[0]
else:
cls = default_class
schema = _resolve_references(schema, rootschema)
if "anyOf" in schema or "oneOf" in schema:
schemas = schema.get("anyOf", []) + schema.get("oneOf", [])
for possible_schema in schemas:
resolver = jsonschema.RefResolver.from_schema(rootschema)
try:
jsonschema.validate(dct, possible_schema, resolver=resolver)
except jsonschema.ValidationError:
continue
else:
return self.from_dict(
dct,
schema=possible_schema,
rootschema=rootschema,
default_class=cls,
)
if isinstance(dct, dict):
# TODO: handle schemas for additionalProperties/patternProperties
props = schema.get("properties", {})
kwds = {}
for key, val in dct.items():
if key in props:
val = self.from_dict(val, schema=props[key], rootschema=rootschema)
kwds[key] = val
return cls(**kwds)
elif isinstance(dct, list):
item_schema = schema.get("items", {})
dct = [
self.from_dict(val, schema=item_schema, rootschema=rootschema)
for val in dct
]
return cls(dct)
else:
return cls(dct)

View File

@@ -0,0 +1,148 @@
"""
A Simple server used to show altair graphics from a prompt or script.
This is adapted from the mpld3 package; see
https://github.com/mpld3/mpld3/blob/master/mpld3/_server.py
"""
import sys
import threading
import webbrowser
import socket
from http import server
from io import BytesIO as IO
import itertools
import random
JUPYTER_WARNING = """
Note: if you're in the Jupyter notebook, Chart.serve() is not the best
way to view plots. Consider using Chart.display().
You must interrupt the kernel to cancel this command.
"""
# Mock server used for testing
class MockRequest(object):
def makefile(self, *args, **kwargs):
return IO(b"GET /")
def sendall(self, response):
pass
class MockServer(object):
def __init__(self, ip_port, Handler):
Handler(MockRequest(), ip_port[0], self)
def serve_forever(self):
pass
def server_close(self):
pass
def generate_handler(html, files=None):
if files is None:
files = {}
class MyHandler(server.BaseHTTPRequestHandler):
def do_GET(self):
"""Respond to a GET request."""
if self.path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(html.encode())
elif self.path in files:
content_type, content = files[self.path]
self.send_response(200)
self.send_header("Content-type", content_type)
self.end_headers()
self.wfile.write(content.encode())
else:
self.send_error(404)
return MyHandler
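# Illustrative sketch: the ``files`` mapping pairs a request path with a
# (content_type, content) tuple, matching how do_GET unpacks it above.
#
#   >>> files = {"/data.json": ("application/json", '{"values": []}')}
#   >>> Handler = generate_handler("<html></html>", files)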
def find_open_port(ip, port, n=50):
"""Find an open port near the specified port"""
ports = itertools.chain(
(port + i for i in range(n)), (port + random.randint(-2 * n, 2 * n) for _ in range(n))
)
for port in ports:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = s.connect_ex((ip, port))
s.close()
if result != 0:
return port
raise ValueError("no open ports found")
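# Illustrative sketch: if the requested port is free it is returned as-is;
# otherwise nearby ports (and then random offsets) are probed until one is
# available.
#
#   >>> find_open_port("127.0.0.1", 8888)   # assuming 8888 is not in use
#   8888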
def serve(
html,
ip="127.0.0.1",
port=8888,
n_retries=50,
files=None,
jupyter_warning=True,
open_browser=True,
http_server=None,
):
"""Start a server serving the given HTML, and (optionally) open a browser
Parameters
----------
html : string
HTML to serve
ip : string (default = '127.0.0.1')
ip address at which the HTML will be served.
port : int (default = 8888)
the port at which to serve the HTML
n_retries : int (default = 50)
the number of nearby ports to search if the specified port is in use.
files : dictionary (optional)
dictionary of extra content to serve
jupyter_warning : bool (optional)
if True (default), then print a warning if this is used within Jupyter
open_browser : bool (optional)
if True (default), then open a web browser to the given HTML
http_server : class (optional)
optionally specify an HTTPServer class to use for showing the
figure. The default is Python's basic HTTPServer.
"""
port = find_open_port(ip, port, n_retries)
Handler = generate_handler(html, files)
if http_server is None:
srvr = server.HTTPServer((ip, port), Handler)
else:
srvr = http_server((ip, port), Handler)
if jupyter_warning:
try:
__IPYTHON__ # noqa
except NameError:
pass
else:
print(JUPYTER_WARNING)
# Start the server
print("Serving to http://{}:{}/ [Ctrl-C to exit]".format(ip, port))
sys.stdout.flush()
if open_browser:
# Use a thread to open a web browser pointing to the server
def b():
return webbrowser.open("http://{}:{}".format(ip, port))
threading.Thread(target=b).start()
try:
srvr.serve_forever()
except (KeyboardInterrupt, SystemExit):
print("\nstopping Server...")
srvr.server_close()
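# Usage sketch (not executed on import), assuming port 8000 is free: serving a
# minimal HTML page from a script; the call blocks until interrupted with Ctrl-C.
#
#   >>> serve("<html><body><p>Hello</p></body></html>",
#   ...       port=8000, open_browser=False)
#   Serving to http://127.0.0.1:8000/ [Ctrl-C to exit]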

View File

@@ -0,0 +1,265 @@
import types
import numpy as np
import pandas as pd
import pytest
import altair as alt
from .. import parse_shorthand, update_nested, infer_encoding_types
from ..core import infer_dtype
FAKE_CHANNELS_MODULE = '''
"""Fake channels module for utility tests."""
from altair.utils import schemapi
class FieldChannel(object):
def __init__(self, shorthand, **kwargs):
kwargs['shorthand'] = shorthand
return super(FieldChannel, self).__init__(**kwargs)
class ValueChannel(object):
def __init__(self, value, **kwargs):
kwargs['value'] = value
return super(ValueChannel, self).__init__(**kwargs)
class X(FieldChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "x"
class XValue(ValueChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "x"
class Y(FieldChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "y"
class YValue(ValueChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "y"
class StrokeWidth(FieldChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "strokeWidth"
class StrokeWidthValue(ValueChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "strokeWidth"
'''
@pytest.mark.parametrize(
"value,expected_type",
[
([1, 2, 3], "integer"),
([1.0, 2.0, 3.0], "floating"),
([1, 2.0, 3], "mixed-integer-float"),
(["a", "b", "c"], "string"),
(["a", "b", np.nan], "mixed"),
],
)
def test_infer_dtype(value, expected_type):
assert infer_dtype(value) == expected_type
def test_parse_shorthand():
def check(s, **kwargs):
assert parse_shorthand(s) == kwargs
check("")
# Fields alone
check("foobar", field="foobar")
check("blah:(fd ", field="blah:(fd ")
# Fields with type
check("foobar:quantitative", type="quantitative", field="foobar")
check("foobar:nominal", type="nominal", field="foobar")
check("foobar:ordinal", type="ordinal", field="foobar")
check("foobar:temporal", type="temporal", field="foobar")
check("foobar:geojson", type="geojson", field="foobar")
check("foobar:Q", type="quantitative", field="foobar")
check("foobar:N", type="nominal", field="foobar")
check("foobar:O", type="ordinal", field="foobar")
check("foobar:T", type="temporal", field="foobar")
check("foobar:G", type="geojson", field="foobar")
# Fields with aggregate and/or type
check("average(foobar)", field="foobar", aggregate="average")
check("min(foobar):temporal", type="temporal", field="foobar", aggregate="min")
check("sum(foobar):Q", type="quantitative", field="foobar", aggregate="sum")
# check that invalid arguments are not split-out
check("invalid(blah)", field="invalid(blah)")
check("blah:invalid", field="blah:invalid")
check("invalid(blah):invalid", field="invalid(blah):invalid")
# check parsing in presence of strange characters
check(
"average(a b:(c\nd):Q",
aggregate="average",
field="a b:(c\nd",
type="quantitative",
)
# special case: count doesn't need an argument
check("count()", aggregate="count", type="quantitative")
check("count():O", aggregate="count", type="ordinal")
# time units:
check("month(x)", field="x", timeUnit="month", type="temporal")
check("year(foo):O", field="foo", timeUnit="year", type="ordinal")
check("date(date):quantitative", field="date", timeUnit="date", type="quantitative")
check(
"yearmonthdate(field)", field="field", timeUnit="yearmonthdate", type="temporal"
)
def test_parse_shorthand_with_data():
def check(s, data, **kwargs):
assert parse_shorthand(s, data) == kwargs
data = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"y": ["A", "B", "C", "D", "E"],
"z": pd.date_range("2018-01-01", periods=5, freq="D"),
"t": pd.date_range("2018-01-01", periods=5, freq="D").tz_localize("UTC"),
}
)
check("x", data, field="x", type="quantitative")
check("y", data, field="y", type="nominal")
check("z", data, field="z", type="temporal")
check("t", data, field="t", type="temporal")
check("count(x)", data, field="x", aggregate="count", type="quantitative")
check("count()", data, aggregate="count", type="quantitative")
check("month(z)", data, timeUnit="month", field="z", type="temporal")
check("month(t)", data, timeUnit="month", field="t", type="temporal")
def test_parse_shorthand_all_aggregates():
aggregates = alt.Root._schema["definitions"]["AggregateOp"]["enum"]
for aggregate in aggregates:
shorthand = "{aggregate}(field):Q".format(aggregate=aggregate)
assert parse_shorthand(shorthand) == {
"aggregate": aggregate,
"field": "field",
"type": "quantitative",
}
def test_parse_shorthand_all_timeunits():
timeUnits = []
for loc in ["Local", "Utc"]:
for typ in ["Single", "Multi"]:
defn = loc + typ + "TimeUnit"
timeUnits.extend(alt.Root._schema["definitions"][defn]["enum"])
for timeUnit in timeUnits:
shorthand = "{timeUnit}(field):Q".format(timeUnit=timeUnit)
assert parse_shorthand(shorthand) == {
"timeUnit": timeUnit,
"field": "field",
"type": "quantitative",
}
def test_parse_shorthand_window_count():
shorthand = "count()"
dct = parse_shorthand(
shorthand,
parse_aggregates=False,
parse_window_ops=True,
parse_timeunits=False,
parse_types=False,
)
assert dct == {"op": "count"}
def test_parse_shorthand_all_window_ops():
window_ops = alt.Root._schema["definitions"]["WindowOnlyOp"]["enum"]
aggregates = alt.Root._schema["definitions"]["AggregateOp"]["enum"]
for op in window_ops + aggregates:
shorthand = "{op}(field)".format(op=op)
dct = parse_shorthand(
shorthand,
parse_aggregates=False,
parse_window_ops=True,
parse_timeunits=False,
parse_types=False,
)
assert dct == {"field": "field", "op": op}
def test_update_nested():
original = {"x": {"b": {"foo": 2}, "c": 4}}
update = {"x": {"b": {"foo": 5}, "d": 6}, "y": 40}
output = update_nested(original, update, copy=True)
assert output is not original
assert output == {"x": {"b": {"foo": 5}, "c": 4, "d": 6}, "y": 40}
output2 = update_nested(original, update)
assert output2 is original
assert output == output2
@pytest.fixture
def channels():
channels = types.ModuleType("channels")
exec(FAKE_CHANNELS_MODULE, channels.__dict__)
return channels
def _getargs(*args, **kwargs):
return args, kwargs
def test_infer_encoding_types(channels):
expected = dict(
x=channels.X("xval"),
y=channels.YValue("yval"),
strokeWidth=channels.StrokeWidthValue(value=4),
)
# All positional args
args, kwds = _getargs(
channels.X("xval"), channels.YValue("yval"), channels.StrokeWidthValue(4)
)
assert infer_encoding_types(args, kwds, channels) == expected
# All keyword args
args, kwds = _getargs(x="xval", y=alt.value("yval"), strokeWidth=alt.value(4))
assert infer_encoding_types(args, kwds, channels) == expected
# Mixed positional & keyword
args, kwds = _getargs(
channels.X("xval"), channels.YValue("yval"), strokeWidth=alt.value(4)
)
assert infer_encoding_types(args, kwds, channels) == expected
def test_infer_encoding_types_with_condition(channels):
args, kwds = _getargs(
x=alt.condition("pred1", alt.value(1), alt.value(2)),
y=alt.condition("pred2", alt.value(1), "yval"),
strokeWidth=alt.condition("pred3", "sval", alt.value(2)),
)
expected = dict(
x=channels.XValue(2, condition=channels.XValue(1, test="pred1")),
y=channels.Y("yval", condition=channels.YValue(1, test="pred2")),
strokeWidth=channels.StrokeWidthValue(
2, condition=channels.StrokeWidth("sval", test="pred3")
),
)
assert infer_encoding_types(args, kwds, channels) == expected

View File

@@ -0,0 +1,139 @@
import os
import pytest
import pandas as pd
from toolz import pipe
from ..data import limit_rows, MaxRowsError, sample, to_values, to_json, to_csv
def _create_dataframe(N):
data = pd.DataFrame({"x": range(N), "y": range(N)})
return data
def _create_data_with_values(N):
data = {"values": [{"x": i, "y": i + 1} for i in range(N)]}
return data
def test_limit_rows():
"""Test the limit_rows data transformer."""
data = _create_dataframe(10)
result = limit_rows(data, max_rows=20)
assert data is result
with pytest.raises(MaxRowsError):
pipe(data, limit_rows(max_rows=5))
data = _create_data_with_values(10)
result = pipe(data, limit_rows(max_rows=20))
assert data is result
with pytest.raises(MaxRowsError):
limit_rows(data, max_rows=5)
def test_sample():
"""Test the sample data transformer."""
data = _create_dataframe(20)
result = pipe(data, sample(n=10))
assert len(result) == 10
assert isinstance(result, pd.DataFrame)
data = _create_data_with_values(20)
result = sample(data, n=10)
assert isinstance(result, dict)
assert "values" in result
assert len(result["values"]) == 10
data = _create_dataframe(20)
result = pipe(data, sample(frac=0.5))
assert len(result) == 10
assert isinstance(result, pd.DataFrame)
data = _create_data_with_values(20)
result = sample(data, frac=0.5)
assert isinstance(result, dict)
assert "values" in result
assert len(result["values"]) == 10
def test_to_values():
"""Test the to_values data transformer."""
data = _create_dataframe(10)
result = pipe(data, to_values)
assert result == {"values": data.to_dict(orient="records")}
def test_type_error():
"""Ensure that TypeError is raised for types other than dict/DataFrame."""
for f in (sample, limit_rows, to_values):
with pytest.raises(TypeError):
pipe(0, f)
def test_dataframe_to_json():
"""Test to_json
- make certain the filename is deterministic
- make certain the file contents match the data
"""
data = _create_dataframe(10)
try:
result1 = pipe(data, to_json)
result2 = pipe(data, to_json)
filename = result1["url"]
output = pd.read_json(filename)
finally:
os.remove(filename)
assert result1 == result2
assert output.equals(data)
def test_dict_to_json():
"""Test to_json
- make certain the filename is deterministic
- make certain the file contents match the data
"""
data = _create_data_with_values(10)
try:
result1 = pipe(data, to_json)
result2 = pipe(data, to_json)
filename = result1["url"]
output = pd.read_json(filename).to_dict(orient="records")
finally:
os.remove(filename)
assert result1 == result2
assert data == {"values": output}
def test_dataframe_to_csv():
"""Test to_csv with dataframe input
- make certain the filename is deterministic
- make certain the file contents match the data
"""
data = _create_dataframe(10)
try:
result1 = pipe(data, to_csv)
result2 = pipe(data, to_csv)
filename = result1["url"]
output = pd.read_csv(filename)
finally:
os.remove(filename)
assert result1 == result2
assert output.equals(data)
def test_dict_to_csv():
"""Test to_csv with dict input
- make certain the filename is deterministic
- make certain the file contents match the data
"""
data = _create_data_with_values(10)
try:
result1 = pipe(data, to_csv)
result2 = pipe(data, to_csv)
filename = result1["url"]
output = pd.read_csv(filename).to_dict(orient="records")
finally:
os.remove(filename)
assert result1 == result2
assert data == {"values": output}

View File

@@ -0,0 +1,24 @@
import pytest
import altair as alt
from altair.utils import AltairDeprecationWarning
from altair.utils.deprecation import _deprecate, deprecated
def test_deprecated_class():
OldChart = _deprecate(alt.Chart, "OldChart")
with pytest.warns(AltairDeprecationWarning) as record:
OldChart()
assert "alt.OldChart" in record[0].message.args[0]
assert "alt.Chart" in record[0].message.args[0]
def test_deprecation_decorator():
@deprecated(message="func is deprecated")
def func(x):
return x + 1
with pytest.warns(AltairDeprecationWarning) as record:
y = func(1)
assert y == 2
assert record[0].message.args[0] == "func is deprecated"

View File

@@ -0,0 +1,30 @@
from ..execeval import eval_block
HAS_RETURN = """
x = 4
y = 2 * x
3 * y
"""
NO_RETURN = """
x = 4
y = 2 * x
z = 3 * y
"""
def test_eval_block_with_return():
_globals = {}
result = eval_block(HAS_RETURN, _globals)
assert result == 24
assert _globals["x"] == 4
assert _globals["y"] == 8
def test_eval_block_without_return():
_globals = {}
result = eval_block(NO_RETURN, _globals)
assert result is None
assert _globals["x"] == 4
assert _globals["y"] == 8
assert _globals["z"] == 24

View File

@@ -0,0 +1,52 @@
import pytest
from ..html import spec_to_html
@pytest.fixture
def spec():
return {
"data": {"url": "data.json"},
"mark": "point",
"encoding": {
"x": {"field": "x", "type": "quantitative"},
"y": {"field": "y", "type": "quantitative"},
},
}
@pytest.mark.parametrize("requirejs", [True, False])
@pytest.mark.parametrize("fullhtml", [True, False])
def test_spec_to_html(requirejs, fullhtml, spec):
# We can't test that the html actually renders, but we'll test aspects of
# it to make certain that the keywords are respected.
vegaembed_version = ("3.12",)
vegalite_version = ("3.0",)
vega_version = "4.0"
html = spec_to_html(
spec,
mode="vega-lite",
requirejs=requirejs,
fullhtml=fullhtml,
vegalite_version=vegalite_version,
vegaembed_version=vegaembed_version,
vega_version=vega_version,
)
html = html.strip()
if fullhtml:
assert html.startswith("<!DOCTYPE html>")
assert html.endswith("</html>")
else:
assert html.startswith("<style>")
assert html.endswith("</script>")
if requirejs:
assert "require(" in html
else:
assert "require(" not in html
assert "vega-lite@{}".format(vegalite_version) in html
assert "vega@{}".format(vega_version) in html
assert "vega-embed@{}".format(vegaembed_version) in html

View File

@@ -0,0 +1,207 @@
import pytest
import altair as alt
from ..mimebundle import spec_to_mimebundle
@pytest.fixture
def require_altair_saver():
try:
import altair_saver # noqa: F401
except ImportError:
pytest.skip("altair_saver not importable; cannot run saver tests")
@pytest.fixture
def vegalite_spec():
return {
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
"description": "A simple bar chart with embedded data.",
"data": {
"values": [
{"a": "A", "b": 28},
{"a": "B", "b": 55},
{"a": "C", "b": 43},
{"a": "D", "b": 91},
{"a": "E", "b": 81},
{"a": "F", "b": 53},
{"a": "G", "b": 19},
{"a": "H", "b": 87},
{"a": "I", "b": 52},
]
},
"mark": "bar",
"encoding": {
"x": {"field": "a", "type": "ordinal"},
"y": {"field": "b", "type": "quantitative"},
},
}
@pytest.fixture
def vega_spec():
return {
"$schema": "https://vega.github.io/schema/vega/v5.json",
"axes": [
{
"aria": False,
"domain": False,
"grid": True,
"gridScale": "x",
"labels": False,
"maxExtent": 0,
"minExtent": 0,
"orient": "left",
"scale": "y",
"tickCount": {"signal": "ceil(height/40)"},
"ticks": False,
"zindex": 0,
},
{
"grid": False,
"labelAlign": "right",
"labelAngle": 270,
"labelBaseline": "middle",
"orient": "bottom",
"scale": "x",
"title": "a",
"zindex": 0,
},
{
"grid": False,
"labelOverlap": True,
"orient": "left",
"scale": "y",
"tickCount": {"signal": "ceil(height/40)"},
"title": "b",
"zindex": 0,
},
],
"background": "white",
"data": [
{
"name": "source_0",
"values": [
{"a": "A", "b": 28},
{"a": "B", "b": 55},
{"a": "C", "b": 43},
{"a": "D", "b": 91},
{"a": "E", "b": 81},
{"a": "F", "b": 53},
{"a": "G", "b": 19},
{"a": "H", "b": 87},
{"a": "I", "b": 52},
],
},
{
"name": "data_0",
"source": "source_0",
"transform": [
{
"expr": 'isValid(datum["b"]) && isFinite(+datum["b"])',
"type": "filter",
}
],
},
],
"description": "A simple bar chart with embedded data.",
"height": 200,
"marks": [
{
"encode": {
"update": {
"ariaRoleDescription": {"value": "bar"},
"description": {
"signal": '"a: " + (isValid(datum["a"]) ? datum["a"] : ""+datum["a"]) + "; b: " + (format(datum["b"], ""))'
},
"fill": {"value": "#4c78a8"},
"width": {"band": 1, "scale": "x"},
"x": {"field": "a", "scale": "x"},
"y": {"field": "b", "scale": "y"},
"y2": {"scale": "y", "value": 0},
}
},
"from": {"data": "data_0"},
"name": "marks",
"style": ["bar"],
"type": "rect",
}
],
"padding": 5,
"scales": [
{
"domain": {"data": "data_0", "field": "a", "sort": True},
"name": "x",
"paddingInner": 0.1,
"paddingOuter": 0.05,
"range": {"step": {"signal": "x_step"}},
"type": "band",
},
{
"domain": {"data": "data_0", "field": "b"},
"name": "y",
"nice": True,
"range": [{"signal": "height"}, 0],
"type": "linear",
"zero": True,
},
],
"signals": [
{"name": "x_step", "value": 20},
{
"name": "width",
"update": "bandspace(domain('x').length, 0.1, 0.05) * x_step",
},
],
"style": "cell",
}
def test_vegalite_to_vega_mimebundle(require_altair_saver, vegalite_spec, vega_spec):
# temporary fix for https://github.com/vega/vega-lite/issues/7776
def delete_none(axes):
for axis in axes:
for key, value in list(axis.items()):
if value is None:
del axis[key]
return axes
bundle = spec_to_mimebundle(
spec=vegalite_spec,
format="vega",
mode="vega-lite",
vega_version=alt.VEGA_VERSION,
vegalite_version=alt.VEGALITE_VERSION,
vegaembed_version=alt.VEGAEMBED_VERSION,
)
bundle["application/vnd.vega.v5+json"]["axes"] = delete_none(
bundle["application/vnd.vega.v5+json"]["axes"]
)
assert bundle == {"application/vnd.vega.v5+json": vega_spec}
def test_spec_to_vegalite_mimebundle(vegalite_spec):
bundle = spec_to_mimebundle(
spec=vegalite_spec,
mode="vega-lite",
format="vega-lite",
vegalite_version=alt.VEGALITE_VERSION,
)
assert bundle == {"application/vnd.vegalite.v4+json": vegalite_spec}
def test_spec_to_vega_mimebundle(vega_spec):
bundle = spec_to_mimebundle(
spec=vega_spec, mode="vega", format="vega", vega_version=alt.VEGA_VERSION
)
assert bundle == {"application/vnd.vega.v5+json": vega_spec}
def test_spec_to_json_mimebundle(vegalite_spec):
bundle = spec_to_mimebundle(
spec=vegalite_spec,
mode="vega-lite",
format="json",
)
assert bundle == {"application/json": vegalite_spec}

View File

@@ -0,0 +1,123 @@
from ..plugin_registry import PluginRegistry
from typing import Callable
class TypedCallableRegistry(PluginRegistry[Callable[[int], int]]):
pass
class GeneralCallableRegistry(PluginRegistry):
_global_settings = {"global_setting": None}
@property
def global_setting(self):
return self._global_settings["global_setting"]
@global_setting.setter
def global_setting(self, val):
self._global_settings["global_setting"] = val
def test_plugin_registry():
plugins = TypedCallableRegistry()
assert plugins.names() == []
assert plugins.active == ""
assert plugins.get() is None
assert repr(plugins) == "TypedCallableRegistry(active='', registered=[])"
plugins.register("new_plugin", lambda x: x ** 2)
assert plugins.names() == ["new_plugin"]
assert plugins.active == ""
assert plugins.get() is None
assert repr(plugins) == (
"TypedCallableRegistry(active='', " "registered=['new_plugin'])"
)
plugins.enable("new_plugin")
assert plugins.names() == ["new_plugin"]
assert plugins.active == "new_plugin"
assert plugins.get()(3) == 9
assert repr(plugins) == (
"TypedCallableRegistry(active='new_plugin', " "registered=['new_plugin'])"
)
def test_plugin_registry_extra_options():
plugins = GeneralCallableRegistry()
plugins.register("metadata_plugin", lambda x, p=2: x ** p)
plugins.enable("metadata_plugin")
assert plugins.get()(3) == 9
plugins.enable("metadata_plugin", p=3)
assert plugins.active == "metadata_plugin"
assert plugins.get()(3) == 27
# enabling without changing name
plugins.enable(p=2)
assert plugins.active == "metadata_plugin"
assert plugins.get()(3) == 9
def test_plugin_registry_global_settings():
plugins = GeneralCallableRegistry()
# we need some default plugin, but we won't do anything with it
plugins.register("default", lambda x: x)
plugins.enable("default")
# default value of the global flag
assert plugins.global_setting is None
# enabling changes the global state, not the options
plugins.enable(global_setting=True)
assert plugins.global_setting is True
assert plugins._options == {}
# context manager changes global state temporarily
with plugins.enable(global_setting="temp"):
assert plugins.global_setting == "temp"
assert plugins._options == {}
assert plugins.global_setting is True
assert plugins._options == {}
def test_plugin_registry_context():
plugins = GeneralCallableRegistry()
plugins.register("default", lambda x, p=2: x ** p)
# At first there is no plugin enabled
assert plugins.active == ""
assert plugins.options == {}
# Make sure the context is set and reset correctly
with plugins.enable("default", p=6):
assert plugins.active == "default"
assert plugins.options == {"p": 6}
assert plugins.active == ""
assert plugins.options == {}
# Make sure the context is reset even if there is an error
try:
with plugins.enable("default", p=6):
assert plugins.active == "default"
assert plugins.options == {"p": 6}
raise ValueError()
except ValueError:
pass
assert plugins.active == ""
assert plugins.options == {}
# Enabling without specifying name uses current name
plugins.enable("default", p=2)
with plugins.enable(p=6):
assert plugins.active == "default"
assert plugins.options == {"p": 6}
assert plugins.active == "default"
assert plugins.options == {"p": 2}

View File

@@ -0,0 +1,351 @@
# The contents of this file are automatically written by
# tools/generate_schema_wrapper.py. Do not modify directly.
import copy
import io
import json
import jsonschema
import pickle
import pytest
import numpy as np
from ..schemapi import (
UndefinedType,
SchemaBase,
Undefined,
_FromDict,
SchemaValidationError,
)
# Make tests inherit from _TestSchema, so that when we test from_dict it won't
# try to use SchemaBase objects defined elsewhere as wrappers.
class _TestSchema(SchemaBase):
@classmethod
def _default_wrapper_classes(cls):
return _TestSchema.__subclasses__()
class MySchema(_TestSchema):
_schema = {
"definitions": {
"StringMapping": {
"type": "object",
"additionalProperties": {"type": "string"},
},
"StringArray": {"type": "array", "items": {"type": "string"}},
},
"properties": {
"a": {"$ref": "#/definitions/StringMapping"},
"a2": {"type": "object", "additionalProperties": {"type": "number"}},
"b": {"$ref": "#/definitions/StringArray"},
"b2": {"type": "array", "items": {"type": "number"}},
"c": {"type": ["string", "number"]},
"d": {
"anyOf": [
{"$ref": "#/definitions/StringMapping"},
{"$ref": "#/definitions/StringArray"},
]
},
"e": {"items": [{"type": "string"}, {"type": "string"}]},
},
}
class StringMapping(_TestSchema):
_schema = {"$ref": "#/definitions/StringMapping"}
_rootschema = MySchema._schema
class StringArray(_TestSchema):
_schema = {"$ref": "#/definitions/StringArray"}
_rootschema = MySchema._schema
class Derived(_TestSchema):
_schema = {
"definitions": {
"Foo": {"type": "object", "properties": {"d": {"type": "string"}}},
"Bar": {"type": "string", "enum": ["A", "B"]},
},
"type": "object",
"additionalProperties": False,
"properties": {
"a": {"type": "integer"},
"b": {"type": "string"},
"c": {"$ref": "#/definitions/Foo"},
},
}
class Foo(_TestSchema):
_schema = {"$ref": "#/definitions/Foo"}
_rootschema = Derived._schema
class Bar(_TestSchema):
_schema = {"$ref": "#/definitions/Bar"}
_rootschema = Derived._schema
class SimpleUnion(_TestSchema):
_schema = {"anyOf": [{"type": "integer"}, {"type": "string"}]}
class DefinitionUnion(_TestSchema):
_schema = {"anyOf": [{"$ref": "#/definitions/Foo"}, {"$ref": "#/definitions/Bar"}]}
_rootschema = Derived._schema
class SimpleArray(_TestSchema):
_schema = {
"type": "array",
"items": {"anyOf": [{"type": "integer"}, {"type": "string"}]},
}
class InvalidProperties(_TestSchema):
_schema = {
"type": "object",
"properties": {"for": {}, "as": {}, "vega-lite": {}, "$schema": {}},
}
def test_construct_multifaceted_schema():
dct = {
"a": {"foo": "bar"},
"a2": {"foo": 42},
"b": ["a", "b", "c"],
"b2": [1, 2, 3],
"c": 42,
"d": ["x", "y", "z"],
"e": ["a", "b"],
}
myschema = MySchema.from_dict(dct)
assert myschema.to_dict() == dct
myschema2 = MySchema(**dct)
assert myschema2.to_dict() == dct
assert isinstance(myschema.a, StringMapping)
assert isinstance(myschema.a2, dict)
assert isinstance(myschema.b, StringArray)
assert isinstance(myschema.b2, list)
assert isinstance(myschema.d, StringArray)
def test_schema_cases():
assert Derived(a=4, b="yo").to_dict() == {"a": 4, "b": "yo"}
assert Derived(a=4, c={"d": "hey"}).to_dict() == {"a": 4, "c": {"d": "hey"}}
assert Derived(a=4, b="5", c=Foo(d="val")).to_dict() == {
"a": 4,
"b": "5",
"c": {"d": "val"},
}
assert Foo(d="hello", f=4).to_dict() == {"d": "hello", "f": 4}
assert Derived().to_dict() == {}
assert Foo().to_dict() == {}
with pytest.raises(jsonschema.ValidationError):
# a needs to be an integer
Derived(a="yo").to_dict()
with pytest.raises(jsonschema.ValidationError):
# Foo.d needs to be a string
Derived(c=Foo(4)).to_dict()
with pytest.raises(jsonschema.ValidationError):
# no additional properties allowed
Derived(foo="bar").to_dict()
def test_round_trip():
D = {"a": 4, "b": "yo"}
assert Derived.from_dict(D).to_dict() == D
D = {"a": 4, "c": {"d": "hey"}}
assert Derived.from_dict(D).to_dict() == D
D = {"a": 4, "b": "5", "c": {"d": "val"}}
assert Derived.from_dict(D).to_dict() == D
D = {"d": "hello", "f": 4}
assert Foo.from_dict(D).to_dict() == D
def test_from_dict():
D = {"a": 4, "b": "5", "c": {"d": "val"}}
obj = Derived.from_dict(D)
assert obj.a == 4
assert obj.b == "5"
assert isinstance(obj.c, Foo)
def test_simple_type():
assert SimpleUnion(4).to_dict() == 4
def test_simple_array():
assert SimpleArray([4, 5, "six"]).to_dict() == [4, 5, "six"]
assert SimpleArray.from_dict(list("abc")).to_dict() == list("abc")
def test_definition_union():
obj = DefinitionUnion.from_dict("A")
assert isinstance(obj, Bar)
assert obj.to_dict() == "A"
obj = DefinitionUnion.from_dict("B")
assert isinstance(obj, Bar)
assert obj.to_dict() == "B"
obj = DefinitionUnion.from_dict({"d": "yo"})
assert isinstance(obj, Foo)
assert obj.to_dict() == {"d": "yo"}
def test_invalid_properties():
dct = {"for": 2, "as": 3, "vega-lite": 4, "$schema": 5}
invalid = InvalidProperties.from_dict(dct)
assert invalid["for"] == 2
assert invalid["as"] == 3
assert invalid["vega-lite"] == 4
assert invalid["$schema"] == 5
assert invalid.to_dict() == dct
def test_undefined_singleton():
assert Undefined is UndefinedType()
@pytest.fixture
def dct():
return {
"a": {"foo": "bar"},
"a2": {"foo": 42},
"b": ["a", "b", "c"],
"b2": [1, 2, 3],
"c": 42,
"d": ["x", "y", "z"],
}
def test_copy_method(dct):
myschema = MySchema.from_dict(dct)
# Make sure copy is deep
copy = myschema.copy(deep=True)
copy["a"]["foo"] = "new value"
copy["b"] = ["A", "B", "C"]
copy["c"] = 164
assert myschema.to_dict() == dct
# If we ignore a value, changing the copy changes the original
copy = myschema.copy(deep=True, ignore=["a"])
copy["a"]["foo"] = "new value"
copy["b"] = ["A", "B", "C"]
copy["c"] = 164
mydct = myschema.to_dict()
assert mydct["a"]["foo"] == "new value"
assert mydct["b"][0] == dct["b"][0]
assert mydct["c"] == dct["c"]
# If copy is not deep, then changing copy below top level changes original
copy = myschema.copy(deep=False)
copy["a"]["foo"] = "baz"
copy["b"] = ["A", "B", "C"]
copy["c"] = 164
mydct = myschema.to_dict()
assert mydct["a"]["foo"] == "baz"
assert mydct["b"] == dct["b"]
assert mydct["c"] == dct["c"]
def test_copy_module(dct):
myschema = MySchema.from_dict(dct)
cp = copy.deepcopy(myschema)
cp["a"]["foo"] = "new value"
cp["b"] = ["A", "B", "C"]
cp["c"] = 164
assert myschema.to_dict() == dct
def test_attribute_error():
m = MySchema()
with pytest.raises(AttributeError) as err:
m.invalid_attribute
assert str(err.value) == (
"'MySchema' object has no attribute " "'invalid_attribute'"
)
def test_to_from_json(dct):
json_str = MySchema.from_dict(dct).to_json()
new_dct = MySchema.from_json(json_str).to_dict()
assert new_dct == dct
def test_to_from_pickle(dct):
myschema = MySchema.from_dict(dct)
output = io.BytesIO()
pickle.dump(myschema, output)
output.seek(0)
myschema_new = pickle.load(output)
assert myschema_new.to_dict() == dct
def test_class_with_no_schema():
class BadSchema(SchemaBase):
pass
with pytest.raises(ValueError) as err:
BadSchema(4)
assert str(err.value).startswith("Cannot instantiate object")
@pytest.mark.parametrize("use_json", [True, False])
def test_hash_schema(use_json):
classes = _TestSchema._default_wrapper_classes()
for cls in classes:
hsh1 = _FromDict.hash_schema(cls._schema, use_json=use_json)
hsh2 = _FromDict.hash_schema(cls._schema, use_json=use_json)
assert hsh1 == hsh2
assert hash(hsh1) == hash(hsh2)
def test_schema_validation_error():
try:
MySchema(a={"foo": 4})
the_err = None
except jsonschema.ValidationError as err:
the_err = err
assert isinstance(the_err, SchemaValidationError)
message = str(the_err)
assert message.startswith("Invalid specification")
assert "test_schemapi.MySchema->a" in message
assert "validating {!r}".format(the_err.validator) in message
assert the_err.message in message
def test_serialize_numpy_types():
m = MySchema(
a={"date": np.datetime64("2019-01-01")},
a2={"int64": np.int64(1), "float64": np.float64(2)},
b2=np.arange(4),
)
out = m.to_json()
dct = json.loads(out)
assert dct == {
"a": {"date": "2019-01-01T00:00:00"},
"a2": {"int64": 1, "float64": 2},
"b2": [0, 1, 2, 3],
}

View File

@@ -0,0 +1,10 @@
"""
Test http server
"""
from altair.utils.server import serve, MockServer
def test_serve():
html = "<html><title>Title</title><body><p>Content</p></body></html>"
serve(html, open_browser=False, http_server=MockServer)

View File

@@ -0,0 +1,192 @@
import pytest
import warnings
import json
import numpy as np
import pandas as pd
from .. import infer_vegalite_type, sanitize_dataframe
def test_infer_vegalite_type():
def _check(arr, typ):
assert infer_vegalite_type(arr) == typ
_check(np.arange(5, dtype=float), "quantitative")
_check(np.arange(5, dtype=int), "quantitative")
_check(np.zeros(5, dtype=bool), "nominal")
_check(pd.date_range("2012", "2013"), "temporal")
_check(pd.timedelta_range(365, periods=12), "temporal")
nulled = pd.Series(np.random.randint(10, size=10))
nulled[0] = None
_check(nulled, "quantitative")
_check(["a", "b", "c"], "nominal")
if hasattr(pytest, "warns"): # added in pytest 2.8
with pytest.warns(UserWarning):
_check([], "nominal")
else:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
_check([], "nominal")
def test_sanitize_dataframe():
# create a dataframe with various types
df = pd.DataFrame(
{
"s": list("abcde"),
"f": np.arange(5, dtype=float),
"i": np.arange(5, dtype=int),
"b": np.array([True, False, True, True, False]),
"d": pd.date_range("2012-01-01", periods=5, freq="H"),
"c": pd.Series(list("ababc"), dtype="category"),
"c2": pd.Series([1, "A", 2.5, "B", None], dtype="category"),
"o": pd.Series([np.array(i) for i in range(5)]),
"p": pd.date_range("2012-01-01", periods=5, freq="H").tz_localize("UTC"),
}
)
# add some nulls
df.iloc[0, df.columns.get_loc("s")] = None
df.iloc[0, df.columns.get_loc("f")] = np.nan
df.iloc[0, df.columns.get_loc("d")] = pd.NaT
df.iloc[0, df.columns.get_loc("o")] = np.array(np.nan)
# JSON serialize. This will fail on non-sanitized dataframes
print(df[["s", "c2"]])
df_clean = sanitize_dataframe(df)
print(df_clean[["s", "c2"]])
print(df_clean[["s", "c2"]].to_dict())
s = json.dumps(df_clean.to_dict(orient="records"))
print(s)
# Re-construct pandas dataframe
df2 = pd.read_json(s)
# Re-order the columns to match df
df2 = df2[df.columns]
# Re-apply original types
for col in df:
if str(df[col].dtype).startswith("datetime"):
# astype(datetime) introduces time-zone issues:
# to_datetime() does not.
utc = isinstance(df[col].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype)
df2[col] = pd.to_datetime(df2[col], utc=utc)
else:
df2[col] = df2[col].astype(df[col].dtype)
# pandas doesn't properly recognize np.array(np.nan), so change it here
df.iloc[0, df.columns.get_loc("o")] = np.nan
assert df.equals(df2)
def test_sanitize_dataframe_colnames():
df = pd.DataFrame(np.arange(12).reshape(4, 3))
# Test that RangeIndex is converted to strings
df = sanitize_dataframe(df)
assert [isinstance(col, str) for col in df.columns]
# Test that non-string columns result in an error
df.columns = [4, "foo", "bar"]
with pytest.raises(ValueError) as err:
sanitize_dataframe(df)
assert str(err.value).startswith("Dataframe contains invalid column name: 4.")
def test_sanitize_dataframe_timedelta():
df = pd.DataFrame({"r": pd.timedelta_range(start="1 day", periods=4)})
with pytest.raises(ValueError) as err:
sanitize_dataframe(df)
assert str(err.value).startswith('Field "r" has type "timedelta')
def test_sanitize_dataframe_infs():
df = pd.DataFrame({"x": [0, 1, 2, np.inf, -np.inf, np.nan]})
df_clean = sanitize_dataframe(df)
assert list(df_clean.dtypes) == [object]
assert list(df_clean["x"]) == [0, 1, 2, None, None, None]
@pytest.mark.skipif(
not hasattr(pd, "Int64Dtype"),
reason="Nullable integers not supported in pandas v{}".format(pd.__version__),
)
def test_sanitize_nullable_integers():
df = pd.DataFrame(
{
"int_np": [1, 2, 3, 4, 5],
"int64": pd.Series([1, 2, 3, None, 5], dtype="UInt8"),
"int64_nan": pd.Series([1, 2, 3, float("nan"), 5], dtype="Int64"),
"float": [1.0, 2.0, 3.0, 4, 5.0],
"float_null": [1, 2, None, 4, 5],
"float_inf": [1, 2, None, 4, (float("inf"))],
}
)
df_clean = sanitize_dataframe(df)
assert {col.dtype.name for _, col in df_clean.iteritems()} == {"object"}
result_python = {col_name: list(col) for col_name, col in df_clean.iteritems()}
assert result_python == {
"int_np": [1, 2, 3, 4, 5],
"int64": [1, 2, 3, None, 5],
"int64_nan": [1, 2, 3, None, 5],
"float": [1.0, 2.0, 3.0, 4.0, 5.0],
"float_null": [1.0, 2.0, None, 4.0, 5.0],
"float_inf": [1.0, 2.0, None, 4.0, None],
}
@pytest.mark.skipif(
not hasattr(pd, "StringDtype"),
reason="dedicated String dtype not supported in pandas v{}".format(pd.__version__),
)
def test_sanitize_string_dtype():
df = pd.DataFrame(
{
"string_object": ["a", "b", "c", "d"],
"string_string": pd.array(["a", "b", "c", "d"], dtype="string"),
"string_object_null": ["a", "b", None, "d"],
"string_string_null": pd.array(["a", "b", None, "d"], dtype="string"),
}
)
df_clean = sanitize_dataframe(df)
assert {col.dtype.name for _, col in df_clean.iteritems()} == {"object"}
result_python = {col_name: list(col) for col_name, col in df_clean.iteritems()}
assert result_python == {
"string_object": ["a", "b", "c", "d"],
"string_string": ["a", "b", "c", "d"],
"string_object_null": ["a", "b", None, "d"],
"string_string_null": ["a", "b", None, "d"],
}
@pytest.mark.skipif(
not hasattr(pd, "BooleanDtype"),
reason="Nullable boolean dtype not supported in pandas v{}".format(pd.__version__),
)
def test_sanitize_boolean_dtype():
df = pd.DataFrame(
{
"bool_none": pd.array([True, False, None], dtype="boolean"),
"none": pd.array([None, None, None], dtype="boolean"),
"bool": pd.array([True, False, True], dtype="boolean"),
}
)
df_clean = sanitize_dataframe(df)
assert {col.dtype.name for _, col in df_clean.iteritems()} == {"object"}
result_python = {col_name: list(col) for col_name, col in df_clean.iteritems()}
assert result_python == {
"bool_none": [True, False, None],
"none": [None, None, None],
"bool": [True, False, True],
}

View File

@@ -0,0 +1,10 @@
"""Utilities for registering and working with themes"""
from .plugin_registry import PluginRegistry
from typing import Callable
ThemeType = Callable[..., dict]
class ThemeRegistry(PluginRegistry[ThemeType]):
pass