"""
|
|
Utility routines
|
|
"""
|
|
from collections.abc import Mapping
|
|
from copy import deepcopy
|
|
import json
|
|
import itertools
|
|
import re
|
|
import sys
|
|
import traceback
|
|
import warnings
|
|
|
|
import jsonschema
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
from .schemapi import SchemaBase, Undefined
|
|
|
|
try:
|
|
from pandas.api.types import infer_dtype as _infer_dtype
|
|
except ImportError:
|
|
# Import for pandas < 0.20.0
|
|
from pandas.lib import infer_dtype as _infer_dtype
|
|
|
|
|
|
def infer_dtype(value):
|
|
"""Infer the dtype of the value.
|
|
|
|
This is a compatibility function for pandas infer_dtype,
|
|
with skipna=False regardless of the pandas version.
|
|
"""
|
|
if not hasattr(infer_dtype, "_supports_skipna"):
|
|
try:
|
|
_infer_dtype([1], skipna=False)
|
|
except TypeError:
|
|
# pandas < 0.21.0 don't support skipna keyword
|
|
infer_dtype._supports_skipna = False
|
|
else:
|
|
infer_dtype._supports_skipna = True
|
|
if infer_dtype._supports_skipna:
|
|
return _infer_dtype(value, skipna=False)
|
|
else:
|
|
return _infer_dtype(value)
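
# Illustrative usage (a sketch, not executed at import time): because NaNs are
# not skipped, a column with missing strings reports 'mixed' rather than
# 'string' here:
#
#     >>> infer_dtype(pd.Series(["a", None]))  # doctest: +SKIP
#     'mixed'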


TYPECODE_MAP = {
    "ordinal": "O",
    "nominal": "N",
    "quantitative": "Q",
    "temporal": "T",
    "geojson": "G",
}

INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()}


# aggregates from vega-lite version 4.6.0
AGGREGATES = [
    "argmax",
    "argmin",
    "average",
    "count",
    "distinct",
    "max",
    "mean",
    "median",
    "min",
    "missing",
    "product",
    "q1",
    "q3",
    "ci0",
    "ci1",
    "stderr",
    "stdev",
    "stdevp",
    "sum",
    "valid",
    "values",
    "variance",
    "variancep",
]

# window aggregates from vega-lite version 4.6.0
WINDOW_AGGREGATES = [
    "row_number",
    "rank",
    "dense_rank",
    "percent_rank",
    "cume_dist",
    "ntile",
    "lag",
    "lead",
    "first_value",
    "last_value",
    "nth_value",
]

# timeUnits from vega-lite version 4.17.0
TIMEUNITS = [
    "year",
    "quarter",
    "month",
    "week",
    "day",
    "dayofyear",
    "date",
    "hours",
    "minutes",
    "seconds",
    "milliseconds",
    "yearquarter",
    "yearquartermonth",
    "yearmonth",
    "yearmonthdate",
    "yearmonthdatehours",
    "yearmonthdatehoursminutes",
    "yearmonthdatehoursminutesseconds",
    "yearweek",
    "yearweekday",
    "yearweekdayhours",
    "yearweekdayhoursminutes",
    "yearweekdayhoursminutesseconds",
    "yeardayofyear",
    "quartermonth",
    "monthdate",
    "monthdatehours",
    "monthdatehoursminutes",
    "monthdatehoursminutesseconds",
    "weekday",
    "weekdayhours",
    "weekdayhoursminutes",
    "weekdayhoursminutesseconds",
    "dayhours",
    "dayhoursminutes",
    "dayhoursminutesseconds",
    "hoursminutes",
    "hoursminutesseconds",
    "minutesseconds",
    "secondsmilliseconds",
    "utcyear",
    "utcquarter",
    "utcmonth",
    "utcweek",
    "utcday",
    "utcdayofyear",
    "utcdate",
    "utchours",
    "utcminutes",
    "utcseconds",
    "utcmilliseconds",
    "utcyearquarter",
    "utcyearquartermonth",
    "utcyearmonth",
    "utcyearmonthdate",
    "utcyearmonthdatehours",
    "utcyearmonthdatehoursminutes",
    "utcyearmonthdatehoursminutesseconds",
    "utcyearweek",
    "utcyearweekday",
    "utcyearweekdayhours",
    "utcyearweekdayhoursminutes",
    "utcyearweekdayhoursminutesseconds",
    "utcyeardayofyear",
    "utcquartermonth",
    "utcmonthdate",
    "utcmonthdatehours",
    "utcmonthdatehoursminutes",
    "utcmonthdatehoursminutesseconds",
    "utcweekday",
    "utcweekdayhours",
    "utcweekdayhoursminutes",
    "utcweekdayhoursminutesseconds",
    "utcdayhours",
    "utcdayhoursminutes",
    "utcdayhoursminutesseconds",
    "utchoursminutes",
    "utchoursminutesseconds",
    "utcminutesseconds",
    "utcsecondsmilliseconds",
]


def infer_vegalite_type(data):
    """
    From an array-like input, infer the correct vega typecode
    ('ordinal', 'nominal', 'quantitative', or 'temporal')

    Parameters
    ----------
    data: Numpy array or Pandas Series
    """
    # Infer based on the dtype of the input
    typ = infer_dtype(data)

    # TODO: Once this returns 'O', please update test_select_x and test_select_y in test_api.py

    if typ in [
        "floating",
        "mixed-integer-float",
        "integer",
        "mixed-integer",
        "complex",
    ]:
        return "quantitative"
    elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
        return "nominal"
    elif typ in [
        "datetime",
        "datetime64",
        "timedelta",
        "timedelta64",
        "date",
        "time",
        "period",
    ]:
        return "temporal"
    else:
        warnings.warn(
            "I don't know how to infer vegalite type from '{}'. "
            "Defaulting to nominal.".format(typ)
        )
        return "nominal"


def merge_props_geom(feat):
    """
    Merge properties with geometry
    * Overwrites 'type' and 'geometry' entries if existing
    """

    geom = {k: feat[k] for k in ("type", "geometry")}
    try:
        feat["properties"].update(geom)
        props_geom = feat["properties"]
    except (AttributeError, KeyError):
        # AttributeError when 'properties' is None
        # KeyError when 'properties' is missing
        props_geom = geom

    return props_geom


def sanitize_geo_interface(geo):
    """Sanitize a geo_interface to prepare it for serialization.

    * Make a copy
    * Convert type array or _Array to list
    * Convert tuples to lists (using json.loads/dumps)
    * Merge properties with geometry
    """

    geo = deepcopy(geo)

    # convert type _Array or array to list
    for key in geo.keys():
        if str(type(geo[key]).__name__).startswith(("_Array", "array")):
            geo[key] = geo[key].tolist()

    # convert (nested) tuples to lists
    geo = json.loads(json.dumps(geo))

    # sanitize features
    if geo["type"] == "FeatureCollection":
        geo = geo["features"]
        if len(geo) > 0:
            for idx, feat in enumerate(geo):
                geo[idx] = merge_props_geom(feat)
    elif geo["type"] == "Feature":
        geo = merge_props_geom(geo)
    else:
        geo = {"type": "Feature", "geometry": geo}

    return geo
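
# Illustrative usage (a sketch with a hypothetical bare geometry): a plain
# geometry is wrapped in a Feature, and tuples become lists via the JSON
# round-trip:
#
#     >>> sanitize_geo_interface({"type": "Point", "coordinates": (1.0, 2.0)})
#     {'type': 'Feature', 'geometry': {'type': 'Point', 'coordinates': [1.0, 2.0]}}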


def sanitize_dataframe(df):  # noqa: C901
    """Sanitize a DataFrame to prepare it for serialization.

    * Make a copy
    * Convert RangeIndex columns to strings
    * Raise ValueError if column names are not strings
    * Raise ValueError if it has a hierarchical index.
    * Convert categoricals to strings.
    * Convert np.bool_ dtypes to Python bool objects
    * Convert np.int dtypes to Python int objects
    * Convert floats to objects and replace NaNs/infs with None.
    * Convert DateTime dtypes into appropriate string representations
    * Convert Nullable integers to objects and replace NaN with None
    * Convert Nullable boolean to objects and replace NaN with None
    * Convert dedicated string columns to objects and replace NaN with None
    * Raise a ValueError for TimeDelta dtypes
    """
    df = df.copy()

    if isinstance(df.columns, pd.RangeIndex):
        df.columns = df.columns.astype(str)

    for col in df.columns:
        if not isinstance(col, str):
            raise ValueError(
                "Dataframe contains invalid column name: {0!r}. "
                "Column names must be strings".format(col)
            )

    if isinstance(df.index, pd.MultiIndex):
        raise ValueError("Hierarchical indices not supported")
    if isinstance(df.columns, pd.MultiIndex):
        raise ValueError("Hierarchical indices not supported")

    def to_list_if_array(val):
        if isinstance(val, np.ndarray):
            return val.tolist()
        else:
            return val

    for col_name, dtype in df.dtypes.iteritems():
        if str(dtype) == "category":
            # XXXX: work around bug in to_json for categorical types
            # https://github.com/pydata/pandas/issues/10778
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif str(dtype) == "string":
            # dedicated string datatype (since 1.0)
            # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif str(dtype) == "bool":
            # convert numpy bools to objects; np.bool is not JSON serializable
            df[col_name] = df[col_name].astype(object)
        elif str(dtype) == "boolean":
            # dedicated boolean datatype (since 1.0)
            # https://pandas.io/docs/user_guide/boolean.html
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif str(dtype).startswith("datetime"):
            # Convert datetimes to strings. This needs to be a full ISO string
            # with time, which is why we cannot use ``col.astype(str)``.
            # This is because Javascript parses date-only strings in UTC, but
            # parses full ISO-8601 date-time strings as local time, and dates
            # in Vega and Vega-Lite are displayed in local time by default.
            # (see https://github.com/altair-viz/altair/issues/1027)
            df[col_name] = (
                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
            )
        elif str(dtype).startswith("timedelta"):
            raise ValueError(
                'Field "{col_name}" has type "{dtype}" which is '
                "not supported by Altair. Please convert to "
                "either a timestamp or a numerical value."
                "".format(col_name=col_name, dtype=dtype)
            )
        elif str(dtype).startswith("geometry"):
            # geopandas >=0.6.1 uses the dtype geometry. Continue here
            # otherwise it will give an error on np.issubdtype(dtype, np.integer)
            continue
        elif str(dtype) in {
            "Int8",
            "Int16",
            "Int32",
            "Int64",
            "UInt8",
            "UInt16",
            "UInt32",
            "UInt64",
            "Float32",
            "Float64",
        }:
            # nullable integer datatypes (since 0.24.0) and nullable float
            # datatypes (since 1.2.0)
            # https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif np.issubdtype(dtype, np.integer):
            # convert integers to objects; np.int is not JSON serializable
            df[col_name] = df[col_name].astype(object)
        elif np.issubdtype(dtype, np.floating):
            # For floats, convert to Python float: np.float is not JSON serializable
            # Also convert NaN/inf values to null, as they are not JSON serializable
            col = df[col_name]
            bad_values = col.isnull() | np.isinf(col)
            df[col_name] = col.astype(object).where(~bad_values, None)
        elif dtype == object:
            # Convert numpy arrays saved as objects to lists
            # Arrays are not JSON serializable
            col = df[col_name].apply(to_list_if_array, convert_dtype=False)
            df[col_name] = col.where(col.notnull(), None)
    return df
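
# Illustrative usage (a sketch; the column name is hypothetical): floats with
# NaN come back as Python objects with None in place of NaN, ready for JSON:
#
#     >>> frame = pd.DataFrame({"x": [1.0, np.nan]})
#     >>> sanitize_dataframe(frame)["x"].tolist()  # doctest: +SKIP
#     [1.0, None]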


def parse_shorthand(
    shorthand,
    data=None,
    parse_aggregates=True,
    parse_window_ops=False,
    parse_timeunits=True,
    parse_types=True,
):
    """General tool to parse shorthand values

    These are of the form:

    - "col_name"
    - "col_name:O"
    - "average(col_name)"
    - "average(col_name):O"

    Optionally, a dataframe may be supplied, from which the type
    will be inferred if not specified in the shorthand.

    Parameters
    ----------
    shorthand : dict or string
        The shorthand representation to be parsed
    data : DataFrame, optional
        If specified and of type DataFrame, then use these values to infer the
        column type if not provided by the shorthand.
    parse_aggregates : boolean
        If True (default), then parse aggregate functions within the shorthand.
    parse_window_ops : boolean
        If True, then parse window operations within the shorthand (default: False).
    parse_timeunits : boolean
        If True (default), then parse timeUnits from within the shorthand.
    parse_types : boolean
        If True (default), then parse typecodes within the shorthand.

    Returns
    -------
    attrs : dict
        a dictionary of attributes extracted from the shorthand

    Examples
    --------
    >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'],
    ...                      'bar': [1, 2, 3, 4]})

    >>> parse_shorthand('name') == {'field': 'name'}
    True

    >>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'}
    True

    >>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'}
    True

    >>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'}
    True

    >>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'}
    True

    >>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'}
    True

    >>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'}
    True

    >>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'}
    True

    >>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'}
    True

    >>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'}
    True
    """
    if not shorthand:
        return {}

    valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP)

    units = dict(
        field="(?P<field>.*)",
        type="(?P<type>{})".format("|".join(valid_typecodes)),
        agg_count="(?P<aggregate>count)",
        op_count="(?P<op>count)",
        aggregate="(?P<aggregate>{})".format("|".join(AGGREGATES)),
        window_op="(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
        timeUnit="(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
    )

    patterns = []

    if parse_aggregates:
        patterns.extend([r"{agg_count}\(\)"])
        patterns.extend([r"{aggregate}\({field}\)"])
    if parse_window_ops:
        patterns.extend([r"{op_count}\(\)"])
        patterns.extend([r"{window_op}\({field}\)"])
    if parse_timeunits:
        patterns.extend([r"{timeUnit}\({field}\)"])

    patterns.extend([r"{field}"])

    if parse_types:
        patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns)))

    regexps = (
        re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns
    )

    # find matches depending on valid fields passed
    if isinstance(shorthand, dict):
        attrs = shorthand
    else:
        attrs = next(
            exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand)
        )

    # Handle short form of the type expression
    if "type" in attrs:
        attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"])

    # counts are quantitative by default
    if attrs == {"aggregate": "count"}:
        attrs["type"] = "quantitative"

    # times are temporal by default
    if "timeUnit" in attrs and "type" not in attrs:
        attrs["type"] = "temporal"

    # if data is specified and type is not, infer type from data
    if isinstance(data, pd.DataFrame) and "type" not in attrs:
        if "field" in attrs and attrs["field"] in data.columns:
            attrs["type"] = infer_vegalite_type(data[attrs["field"]])
    return attrs
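
# Illustrative usage of window-operation parsing (a sketch; 'col' is a
# hypothetical column name). Window operations are captured under the 'op'
# key rather than 'aggregate':
#
#     >>> parse_shorthand('lead(col)', parse_window_ops=True) == {'op': 'lead', 'field': 'col'}
#     True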


def use_signature(Obj):
    """Apply call signature and documentation of Obj to the decorated method"""

    def decorate(f):
        # call-signature of f is exposed via __wrapped__.
        # we want it to mimic Obj.__init__
        f.__wrapped__ = Obj.__init__
        f._uses_signature = Obj

        # Supplement the docstring of f with information from Obj
        if Obj.__doc__:
            doclines = Obj.__doc__.splitlines()
            if f.__doc__:
                doc = f.__doc__ + "\n".join(doclines[1:])
            else:
                doc = "\n".join(doclines)
            try:
                f.__doc__ = doc
            except AttributeError:
                # __doc__ is not modifiable for classes in Python < 3.3
                pass

        return f

    return decorate
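
# Illustrative usage (a sketch with a hypothetical class ``Scale``):
#
#     @use_signature(Scale)
#     def scale(self, **kwargs):
#         ...
#
# Introspection tools that honor ``__wrapped__`` (e.g. ``inspect.signature``)
# will then report the signature of ``Scale.__init__`` for ``scale``.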


def update_subtraits(obj, attrs, **kwargs):
    """Recursively update sub-traits without overwriting other traits"""
    # TODO: infer keywords from args
    if not kwargs:
        return obj

    # obj can be a SchemaBase object or a dict
    if obj is Undefined:
        obj = dct = {}
    elif isinstance(obj, SchemaBase):
        dct = obj._kwds
    else:
        dct = obj

    if isinstance(attrs, str):
        attrs = (attrs,)

    if len(attrs) == 0:
        dct.update(kwargs)
    else:
        attr = attrs[0]
        trait = dct.get(attr, Undefined)
        if trait is Undefined:
            trait = dct[attr] = {}
        dct[attr] = update_subtraits(trait, attrs[1:], **kwargs)
    return obj
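
# Illustrative usage (a sketch on a plain dict): the existing 'grid' trait is
# preserved while 'title' is set on the nested 'axis' trait:
#
#     >>> update_subtraits({'axis': {'grid': True}}, ('axis',), title='x')
#     {'axis': {'grid': True, 'title': 'x'}}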


def update_nested(original, update, copy=False):
    """Update nested dictionaries

    Parameters
    ----------
    original : dict
        the original (nested) dictionary, which will be updated in-place
    update : dict
        the nested dictionary of updates
    copy : bool, default False
        if True, then copy the original dictionary rather than modifying it

    Returns
    -------
    original : dict
        a reference to the (modified) original dict

    Examples
    --------
    >>> original = {'x': {'b': 2, 'c': 4}}
    >>> update = {'x': {'b': 5, 'd': 6}, 'y': 40}
    >>> update_nested(original, update)  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    >>> original  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    """
    if copy:
        original = deepcopy(original)
    for key, val in update.items():
        if isinstance(val, Mapping):
            orig_val = original.get(key, {})
            if isinstance(orig_val, Mapping):
                original[key] = update_nested(orig_val, val)
            else:
                original[key] = val
        else:
            original[key] = val
    return original


def display_traceback(in_ipython=True):
    """Display the current exception traceback, via IPython if available."""
    exc_info = sys.exc_info()

    if in_ipython:
        from IPython.core.getipython import get_ipython

        ip = get_ipython()
    else:
        ip = None

    if ip is not None:
        ip.showtraceback(exc_info)
    else:
        traceback.print_exception(*exc_info)


def infer_encoding_types(args, kwargs, channels):
    """Infer typed keyword arguments for args and kwargs

    Parameters
    ----------
    args : tuple
        List of function args
    kwargs : dict
        Dict of function kwargs
    channels : module
        The module containing all altair encoding channel classes.

    Returns
    -------
    kwargs : dict
        All args and kwargs in a single dict, with keys and types
        based on the channels mapping.
    """
    # Construct a dictionary of channel type to encoding name
    # TODO: cache this somehow?
    channel_objs = (getattr(channels, name) for name in dir(channels))
    channel_objs = (
        c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
    )
    channel_to_name = {c: c._encoding_name for c in channel_objs}
    name_to_channel = {}
    for chan, name in channel_to_name.items():
        chans = name_to_channel.setdefault(name, {})
        if chan.__name__.endswith("Datum"):
            key = "datum"
        elif chan.__name__.endswith("Value"):
            key = "value"
        else:
            key = "field"
        chans[key] = chan

    # First use the mapping to convert args to kwargs based on their types.
    for arg in args:
        if isinstance(arg, (list, tuple)) and len(arg) > 0:
            type_ = type(arg[0])
        else:
            type_ = type(arg)

        encoding = channel_to_name.get(type_, None)
        if encoding is None:
            raise NotImplementedError("positional argument of type {}".format(type_))
        if encoding in kwargs:
            raise ValueError("encoding {} specified twice.".format(encoding))
        kwargs[encoding] = arg

    def _wrap_in_channel_class(obj, encoding):
        try:
            condition = obj["condition"]
        except (KeyError, TypeError):
            pass
        else:
            if condition is not Undefined:
                obj = obj.copy()
                obj["condition"] = _wrap_in_channel_class(condition, encoding)

        if isinstance(obj, SchemaBase):
            return obj

        if isinstance(obj, str):
            obj = {"shorthand": obj}

        if isinstance(obj, (list, tuple)):
            return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]

        if encoding not in name_to_channel:
            warnings.warn("Unrecognized encoding channel '{}'".format(encoding))
            return obj

        classes = name_to_channel[encoding]
        cls = classes["value"] if "value" in obj else classes["field"]

        try:
            # Don't force validation here; some objects won't be valid until
            # they're created in the context of a chart.
            return cls.from_dict(obj, validate=False)
        except jsonschema.ValidationError:
            # our attempts at finding the correct class have failed
            return obj

    return {
        encoding: _wrap_in_channel_class(obj, encoding)
        for encoding, obj in kwargs.items()
    }
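
# Illustrative usage (a sketch; assumes ``channels`` is an Altair encoding
# channel module such as ``altair.vegalite.v4.schema.channels`` and that
# ``alt`` is the top-level altair package):
#
#     >>> infer_encoding_types((alt.X('x:Q'),), {'y': 'y:N'}, channels)  # doctest: +SKIP
#
# This returns a dict mapping encoding names to channel instances, here the
# positional X channel under 'x' and the shorthand string wrapped as a Y
# channel under 'y'.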