first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions

View File

@ -0,0 +1,2 @@
from .renderers import Renderer
from .exporter import Exporter

View File

@ -0,0 +1,25 @@
"""
Simple fixes for Python 2/3 compatibility
"""
import sys
PY3K = sys.version_info[0] >= 3
if PY3K:
import builtins
import functools
reduce = functools.reduce
zip = builtins.zip
xrange = builtins.range
map = builtins.map
else:
import __builtin__
import itertools
builtins = __builtin__
reduce = __builtin__.reduce
zip = itertools.izip
xrange = __builtin__.xrange
map = itertools.imap

View File

@ -0,0 +1,310 @@
"""
Matplotlib Exporter
===================
This submodule contains tools for crawling a matplotlib figure and exporting
relevant pieces to a renderer.
"""
import warnings
import io
from . import utils
import matplotlib
from matplotlib import transforms, collections
from matplotlib.backends.backend_agg import FigureCanvasAgg
class Exporter(object):
"""Matplotlib Exporter
Parameters
----------
renderer : Renderer object
The renderer object called by the exporter to create a figure
visualization. See mplexporter.Renderer for information on the
methods which should be defined within the renderer.
close_mpl : bool
If True (default), close the matplotlib figure as it is rendered. This
is useful for when the exporter is used within the notebook, or with
an interactive matplotlib backend.
"""
def __init__(self, renderer, close_mpl=True):
self.close_mpl = close_mpl
self.renderer = renderer
def run(self, fig):
"""
Run the exporter on the given figure
Parmeters
---------
fig : matplotlib.Figure instance
The figure to export
"""
# Calling savefig executes the draw() command, putting elements
# in the correct place.
if fig.canvas is None:
canvas = FigureCanvasAgg(fig)
fig.savefig(io.BytesIO(), format="png", dpi=fig.dpi)
if self.close_mpl:
import matplotlib.pyplot as plt
plt.close(fig)
self.crawl_fig(fig)
@staticmethod
def process_transform(
transform, ax=None, data=None, return_trans=False, force_trans=None
):
"""Process the transform and convert data to figure or data coordinates
Parameters
----------
transform : matplotlib Transform object
The transform applied to the data
ax : matplotlib Axes object (optional)
The axes the data is associated with
data : ndarray (optional)
The array of data to be transformed.
return_trans : bool (optional)
If true, return the final transform of the data
force_trans : matplotlib.transform instance (optional)
If supplied, first force the data to this transform
Returns
-------
code : string
Code is either "data", "axes", "figure", or "display", indicating
the type of coordinates output.
transform : matplotlib transform
the transform used to map input data to output data.
Returned only if return_trans is True
new_data : ndarray
Data transformed to match the given coordinate code.
Returned only if data is specified
"""
if isinstance(transform, transforms.BlendedGenericTransform):
warnings.warn(
"Blended transforms not yet supported. "
"Zoom behavior may not work as expected."
)
if force_trans is not None:
if data is not None:
data = (transform - force_trans).transform(data)
transform = force_trans
code = "display"
if ax is not None:
for (c, trans) in [
("data", ax.transData),
("axes", ax.transAxes),
("figure", ax.figure.transFigure),
("display", transforms.IdentityTransform()),
]:
if transform.contains_branch(trans):
code, transform = (c, transform - trans)
break
if data is not None:
if return_trans:
return code, transform.transform(data), transform
else:
return code, transform.transform(data)
else:
if return_trans:
return code, transform
else:
return code
def crawl_fig(self, fig):
"""Crawl the figure and process all axes"""
with self.renderer.draw_figure(fig=fig, props=utils.get_figure_properties(fig)):
for ax in fig.axes:
self.crawl_ax(ax)
def crawl_ax(self, ax):
"""Crawl the axes and process all elements within"""
with self.renderer.draw_axes(ax=ax, props=utils.get_axes_properties(ax)):
for line in ax.lines:
self.draw_line(ax, line)
for text in ax.texts:
self.draw_text(ax, text)
for (text, ttp) in zip(
[ax.xaxis.label, ax.yaxis.label, ax.title],
["xlabel", "ylabel", "title"],
):
if hasattr(text, "get_text") and text.get_text():
self.draw_text(ax, text, force_trans=ax.transAxes, text_type=ttp)
for artist in ax.artists:
# TODO: process other artists
if isinstance(artist, matplotlib.text.Text):
self.draw_text(ax, artist)
for patch in ax.patches:
self.draw_patch(ax, patch)
for collection in ax.collections:
self.draw_collection(ax, collection)
for image in ax.images:
self.draw_image(ax, image)
legend = ax.get_legend()
if legend is not None:
props = utils.get_legend_properties(ax, legend)
with self.renderer.draw_legend(legend=legend, props=props):
if props["visible"]:
self.crawl_legend(ax, legend)
def crawl_legend(self, ax, legend):
"""
Recursively look through objects in legend children
"""
legendElements = list(
utils.iter_all_children(legend._legend_box, skipContainers=True)
)
legendElements.append(legend.legendPatch)
for child in legendElements:
# force a large zorder so it appears on top
child.set_zorder(1e6 + child.get_zorder())
# reorder border box to make sure marks are visible
if isinstance(child, matplotlib.patches.FancyBboxPatch):
child.set_zorder(child.get_zorder() - 1)
try:
# What kind of object...
if isinstance(child, matplotlib.patches.Patch):
self.draw_patch(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.text.Text):
if child.get_text() != "None":
self.draw_text(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.lines.Line2D):
self.draw_line(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.collections.Collection):
self.draw_collection(ax, child, force_pathtrans=ax.transAxes)
else:
warnings.warn("Legend element %s not impemented" % child)
except NotImplementedError:
warnings.warn("Legend element %s not impemented" % child)
def draw_line(self, ax, line, force_trans=None):
"""Process a matplotlib line and call renderer.draw_line"""
coordinates, data = self.process_transform(
line.get_transform(), ax, line.get_xydata(), force_trans=force_trans
)
linestyle = utils.get_line_style(line)
if linestyle["dasharray"] is None and linestyle["drawstyle"] == "default":
linestyle = None
markerstyle = utils.get_marker_style(line)
if (
markerstyle["marker"] in ["None", "none", None]
or markerstyle["markerpath"][0].size == 0
):
markerstyle = None
label = line.get_label()
if markerstyle or linestyle:
self.renderer.draw_marked_line(
data=data,
coordinates=coordinates,
linestyle=linestyle,
markerstyle=markerstyle,
label=label,
mplobj=line,
)
def draw_text(self, ax, text, force_trans=None, text_type=None):
"""Process a matplotlib text object and call renderer.draw_text"""
content = text.get_text()
if content:
transform = text.get_transform()
position = text.get_position()
coords, position = self.process_transform(
transform, ax, position, force_trans=force_trans
)
style = utils.get_text_style(text)
self.renderer.draw_text(
text=content,
position=position,
coordinates=coords,
text_type=text_type,
style=style,
mplobj=text,
)
def draw_patch(self, ax, patch, force_trans=None):
"""Process a matplotlib patch object and call renderer.draw_path"""
vertices, pathcodes = utils.SVG_path(patch.get_path())
transform = patch.get_transform()
coordinates, vertices = self.process_transform(
transform, ax, vertices, force_trans=force_trans
)
linestyle = utils.get_path_style(patch, fill=patch.get_fill())
self.renderer.draw_path(
data=vertices,
coordinates=coordinates,
pathcodes=pathcodes,
style=linestyle,
mplobj=patch,
)
def draw_collection(
self, ax, collection, force_pathtrans=None, force_offsettrans=None
):
"""Process a matplotlib collection and call renderer.draw_collection"""
(transform, transOffset, offsets, paths) = collection._prepare_points()
offset_coords, offsets = self.process_transform(
transOffset, ax, offsets, force_trans=force_offsettrans
)
path_coords = self.process_transform(transform, ax, force_trans=force_pathtrans)
processed_paths = [utils.SVG_path(path) for path in paths]
processed_paths = [
(
self.process_transform(
transform, ax, path[0], force_trans=force_pathtrans
)[1],
path[1],
)
for path in processed_paths
]
path_transforms = collection.get_transforms()
try:
# matplotlib 1.3: path_transforms are transform objects.
# Convert them to numpy arrays.
path_transforms = [t.get_matrix() for t in path_transforms]
except AttributeError:
# matplotlib 1.4: path transforms are already numpy arrays.
pass
styles = {
"linewidth": collection.get_linewidths(),
"facecolor": collection.get_facecolors(),
"edgecolor": collection.get_edgecolors(),
"alpha": collection._alpha,
"zorder": collection.get_zorder(),
}
offset_dict = {"data": "before", "screen": "after"}
offset_order = offset_dict[collection.get_offset_position()]
self.renderer.draw_path_collection(
paths=processed_paths,
path_coordinates=path_coords,
path_transforms=path_transforms,
offsets=offsets,
offset_coordinates=offset_coords,
offset_order=offset_order,
styles=styles,
mplobj=collection,
)
def draw_image(self, ax, image):
"""Process a matplotlib image object and call renderer.draw_image"""
self.renderer.draw_image(
imdata=utils.image_to_base64(image),
extent=image.get_extent(),
coordinates="data",
style={"alpha": image.get_alpha(), "zorder": image.get_zorder()},
mplobj=image,
)

View File

@ -0,0 +1,12 @@
"""
Matplotlib Renderers
====================
This submodule contains renderer objects which define renderer behavior used
within the Exporter class. The base renderer class is :class:`Renderer`, an
abstract base class
"""
from .base import Renderer
from .vega_renderer import VegaRenderer, fig_to_vega
from .vincent_renderer import VincentRenderer, fig_to_vincent
from .fake_renderer import FakeRenderer, FullFakeRenderer

View File

@ -0,0 +1,429 @@
import warnings
import itertools
from contextlib import contextmanager
from distutils.version import LooseVersion
import numpy as np
import matplotlib as mpl
from matplotlib import transforms
from .. import utils
from .. import _py3k_compat as py3k
class Renderer(object):
@staticmethod
def ax_zoomable(ax):
return bool(ax and ax.get_navigate())
@staticmethod
def ax_has_xgrid(ax):
return bool(ax and ax.xaxis._gridOnMajor and ax.yaxis.get_gridlines())
@staticmethod
def ax_has_ygrid(ax):
return bool(ax and ax.yaxis._gridOnMajor and ax.yaxis.get_gridlines())
@property
def current_ax_zoomable(self):
return self.ax_zoomable(self._current_ax)
@property
def current_ax_has_xgrid(self):
return self.ax_has_xgrid(self._current_ax)
@property
def current_ax_has_ygrid(self):
return self.ax_has_ygrid(self._current_ax)
@contextmanager
def draw_figure(self, fig, props):
if hasattr(self, "_current_fig") and self._current_fig is not None:
warnings.warn("figure embedded in figure: something is wrong")
self._current_fig = fig
self._fig_props = props
self.open_figure(fig=fig, props=props)
yield
self.close_figure(fig=fig)
self._current_fig = None
self._fig_props = {}
@contextmanager
def draw_axes(self, ax, props):
if hasattr(self, "_current_ax") and self._current_ax is not None:
warnings.warn("axes embedded in axes: something is wrong")
self._current_ax = ax
self._ax_props = props
self.open_axes(ax=ax, props=props)
yield
self.close_axes(ax=ax)
self._current_ax = None
self._ax_props = {}
@contextmanager
def draw_legend(self, legend, props):
self._current_legend = legend
self._legend_props = props
self.open_legend(legend=legend, props=props)
yield
self.close_legend(legend=legend)
self._current_legend = None
self._legend_props = {}
# Following are the functions which should be overloaded in subclasses
def open_figure(self, fig, props):
"""
Begin commands for a particular figure.
Parameters
----------
fig : matplotlib.Figure
The Figure which will contain the ensuing axes and elements
props : dictionary
The dictionary of figure properties
"""
pass
def close_figure(self, fig):
"""
Finish commands for a particular figure.
Parameters
----------
fig : matplotlib.Figure
The figure which is finished being drawn.
"""
pass
def open_axes(self, ax, props):
"""
Begin commands for a particular axes.
Parameters
----------
ax : matplotlib.Axes
The Axes which will contain the ensuing axes and elements
props : dictionary
The dictionary of axes properties
"""
pass
def close_axes(self, ax):
"""
Finish commands for a particular axes.
Parameters
----------
ax : matplotlib.Axes
The Axes which is finished being drawn.
"""
pass
def open_legend(self, legend, props):
"""
Beging commands for a particular legend.
Parameters
----------
legend : matplotlib.legend.Legend
The Legend that will contain the ensuing elements
props : dictionary
The dictionary of legend properties
"""
pass
def close_legend(self, legend):
"""
Finish commands for a particular legend.
Parameters
----------
legend : matplotlib.legend.Legend
The Legend which is finished being drawn
"""
pass
def draw_marked_line(
self, data, coordinates, linestyle, markerstyle, label, mplobj=None
):
"""Draw a line that also has markers.
If this isn't reimplemented by a renderer object, by default, it will
make a call to BOTH draw_line and draw_markers when both markerstyle
and linestyle are not None in the same Line2D object.
"""
if linestyle is not None:
self.draw_line(data, coordinates, linestyle, label, mplobj)
if markerstyle is not None:
self.draw_markers(data, coordinates, markerstyle, label, mplobj)
def draw_line(self, data, coordinates, style, label, mplobj=None):
"""
Draw a line. By default, draw the line via the draw_path() command.
Some renderers might wish to override this and provide more
fine-grained behavior.
In matplotlib, lines are generally created via the plt.plot() command,
though this command also can create marker collections.
Parameters
----------
data : array_like
A shape (N, 2) array of datapoints.
coordinates : string
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
style : dictionary
a dictionary specifying the appearance of the line.
mplobj : matplotlib object
the matplotlib plot element which generated this line
"""
pathcodes = ["M"] + (data.shape[0] - 1) * ["L"]
pathstyle = dict(facecolor="none", **style)
pathstyle["edgecolor"] = pathstyle.pop("color")
pathstyle["edgewidth"] = pathstyle.pop("linewidth")
self.draw_path(
data=data,
coordinates=coordinates,
pathcodes=pathcodes,
style=pathstyle,
mplobj=mplobj,
)
@staticmethod
def _iter_path_collection(paths, path_transforms, offsets, styles):
"""Build an iterator over the elements of the path collection"""
N = max(len(paths), len(offsets))
# Before mpl 1.4.0, path_transform can be a false-y value, not a valid
# transformation matrix.
if LooseVersion(mpl.__version__) < LooseVersion("1.4.0"):
if path_transforms is None:
path_transforms = [np.eye(3)]
edgecolor = styles["edgecolor"]
if np.size(edgecolor) == 0:
edgecolor = ["none"]
facecolor = styles["facecolor"]
if np.size(facecolor) == 0:
facecolor = ["none"]
elements = [
paths,
path_transforms,
offsets,
edgecolor,
styles["linewidth"],
facecolor,
]
it = itertools
return it.islice(py3k.zip(*py3k.map(it.cycle, elements)), N)
def draw_path_collection(
self,
paths,
path_coordinates,
path_transforms,
offsets,
offset_coordinates,
offset_order,
styles,
mplobj=None,
):
"""
Draw a collection of paths. The paths, offsets, and styles are all
iterables, and the number of paths is max(len(paths), len(offsets)).
By default, this is implemented via multiple calls to the draw_path()
function. For efficiency, Renderers may choose to customize this
implementation.
Examples of path collections created by matplotlib are scatter plots,
histograms, contour plots, and many others.
Parameters
----------
paths : list
list of tuples, where each tuple has two elements:
(data, pathcodes). See draw_path() for a description of these.
path_coordinates: string
the coordinates code for the paths, which should be either
'data' for data coordinates, or 'figure' for figure (pixel)
coordinates.
path_transforms: array_like
an array of shape (*, 3, 3), giving a series of 2D Affine
transforms for the paths. These encode translations, rotations,
and scalings in the standard way.
offsets: array_like
An array of offsets of shape (N, 2)
offset_coordinates : string
the coordinates code for the offsets, which should be either
'data' for data coordinates, or 'figure' for figure (pixel)
coordinates.
offset_order : string
either "before" or "after". This specifies whether the offset
is applied before the path transform, or after. The matplotlib
backend equivalent is "before"->"data", "after"->"screen".
styles: dictionary
A dictionary in which each value is a list of length N, containing
the style(s) for the paths.
mplobj : matplotlib object
the matplotlib plot element which generated this collection
"""
if offset_order == "before":
raise NotImplementedError("offset before transform")
for tup in self._iter_path_collection(paths, path_transforms, offsets, styles):
(path, path_transform, offset, ec, lw, fc) = tup
vertices, pathcodes = path
path_transform = transforms.Affine2D(path_transform)
vertices = path_transform.transform(vertices)
# This is a hack:
if path_coordinates == "figure":
path_coordinates = "points"
style = {
"edgecolor": utils.export_color(ec),
"facecolor": utils.export_color(fc),
"edgewidth": lw,
"dasharray": "10,0",
"alpha": styles["alpha"],
"zorder": styles["zorder"],
}
self.draw_path(
data=vertices,
coordinates=path_coordinates,
pathcodes=pathcodes,
style=style,
offset=offset,
offset_coordinates=offset_coordinates,
mplobj=mplobj,
)
def draw_markers(self, data, coordinates, style, label, mplobj=None):
"""
Draw a set of markers. By default, this is done by repeatedly
calling draw_path(), but renderers should generally overload
this method to provide a more efficient implementation.
In matplotlib, markers are created using the plt.plot() command.
Parameters
----------
data : array_like
A shape (N, 2) array of datapoints.
coordinates : string
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
style : dictionary
a dictionary specifying the appearance of the markers.
mplobj : matplotlib object
the matplotlib plot element which generated this marker collection
"""
vertices, pathcodes = style["markerpath"]
pathstyle = dict(
(key, style[key])
for key in ["alpha", "edgecolor", "facecolor", "zorder", "edgewidth"]
)
pathstyle["dasharray"] = "10,0"
for vertex in data:
self.draw_path(
data=vertices,
coordinates="points",
pathcodes=pathcodes,
style=pathstyle,
offset=vertex,
offset_coordinates=coordinates,
mplobj=mplobj,
)
def draw_text(
self, text, position, coordinates, style, text_type=None, mplobj=None
):
"""
Draw text on the image.
Parameters
----------
text : string
The text to draw
position : tuple
The (x, y) position of the text
coordinates : string
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
style : dictionary
a dictionary specifying the appearance of the text.
text_type : string or None
if specified, a type of text such as "xlabel", "ylabel", "title"
mplobj : matplotlib object
the matplotlib plot element which generated this text
"""
raise NotImplementedError()
def draw_path(
self,
data,
coordinates,
pathcodes,
style,
offset=None,
offset_coordinates="data",
mplobj=None,
):
"""
Draw a path.
In matplotlib, paths are created by filled regions, histograms,
contour plots, patches, etc.
Parameters
----------
data : array_like
A shape (N, 2) array of datapoints.
coordinates : string
A string code, which should be either 'data' for data coordinates,
'figure' for figure (pixel) coordinates, or "points" for raw
point coordinates (useful in conjunction with offsets, below).
pathcodes : list
A list of single-character SVG pathcodes associated with the data.
Path codes are one of ['M', 'm', 'L', 'l', 'Q', 'q', 'T', 't',
'S', 's', 'C', 'c', 'Z', 'z']
See the SVG specification for details. Note that some path codes
consume more than one datapoint (while 'Z' consumes none), so
in general, the length of the pathcodes list will not be the same
as that of the data array.
style : dictionary
a dictionary specifying the appearance of the line.
offset : list (optional)
the (x, y) offset of the path. If not given, no offset will
be used.
offset_coordinates : string (optional)
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
mplobj : matplotlib object
the matplotlib plot element which generated this path
"""
raise NotImplementedError()
def draw_image(self, imdata, extent, coordinates, style, mplobj=None):
"""
Draw an image.
Parameters
----------
imdata : string
base64 encoded png representation of the image
extent : list
the axes extent of the image: [xmin, xmax, ymin, ymax]
coordinates: string
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
style : dictionary
a dictionary specifying the appearance of the image
mplobj : matplotlib object
the matplotlib plot object which generated this image
"""
raise NotImplementedError()

View File

@ -0,0 +1,88 @@
from .base import Renderer
class FakeRenderer(Renderer):
"""
Fake Renderer
This is a fake renderer which simply outputs a text tree representing the
elements found in the plot(s). This is used in the unit tests for the
package.
Below are the methods your renderer must implement. You are free to do
anything you wish within the renderer (i.e. build an XML or JSON
representation, call an external API, etc.) Here the renderer just
builds a simple string representation for testing purposes.
"""
def __init__(self):
self.output = ""
def open_figure(self, fig, props):
self.output += "opening figure\n"
def close_figure(self, fig):
self.output += "closing figure\n"
def open_axes(self, ax, props):
self.output += " opening axes\n"
def close_axes(self, ax):
self.output += " closing axes\n"
def open_legend(self, legend, props):
self.output += " opening legend\n"
def close_legend(self, legend):
self.output += " closing legend\n"
def draw_text(
self, text, position, coordinates, style, text_type=None, mplobj=None
):
self.output += " draw text '{0}' {1}\n".format(text, text_type)
def draw_path(
self,
data,
coordinates,
pathcodes,
style,
offset=None,
offset_coordinates="data",
mplobj=None,
):
self.output += " draw path with {0} vertices\n".format(data.shape[0])
def draw_image(self, imdata, extent, coordinates, style, mplobj=None):
self.output += " draw image of size {0}\n".format(len(imdata))
class FullFakeRenderer(FakeRenderer):
"""
Renderer with the full complement of methods.
When the following are left undefined, they will be implemented via
other methods in the class. They can be defined explicitly for
more efficient or specialized use within the renderer implementation.
"""
def draw_line(self, data, coordinates, style, label, mplobj=None):
self.output += " draw line with {0} points\n".format(data.shape[0])
def draw_markers(self, data, coordinates, style, label, mplobj=None):
self.output += " draw {0} markers\n".format(data.shape[0])
def draw_path_collection(
self,
paths,
path_coordinates,
path_transforms,
offsets,
offset_coordinates,
offset_order,
styles,
mplobj=None,
):
self.output += " draw path collection " "with {0} offsets\n".format(
offsets.shape[0]
)

View File

@ -0,0 +1,145 @@
import warnings
import json
import random
from .base import Renderer
from ..exporter import Exporter
class VegaRenderer(Renderer):
def open_figure(self, fig, props):
self.props = props
self.figwidth = int(props["figwidth"] * props["dpi"])
self.figheight = int(props["figheight"] * props["dpi"])
self.data = []
self.scales = []
self.axes = []
self.marks = []
def open_axes(self, ax, props):
if len(self.axes) > 0:
warnings.warn("multiple axes not yet supported")
self.axes = [
dict(type="x", scale="x", ticks=10),
dict(type="y", scale="y", ticks=10),
]
self.scales = [
dict(name="x", domain=props["xlim"], type="linear", range="width",),
dict(name="y", domain=props["ylim"], type="linear", range="height",),
]
def draw_line(self, data, coordinates, style, label, mplobj=None):
if coordinates != "data":
warnings.warn("Only data coordinates supported. Skipping this")
dataname = "table{0:03d}".format(len(self.data) + 1)
# TODO: respect the other style settings
self.data.append(
{"name": dataname, "values": [dict(x=d[0], y=d[1]) for d in data]}
)
self.marks.append(
{
"type": "line",
"from": {"data": dataname},
"properties": {
"enter": {
"interpolate": {"value": "monotone"},
"x": {"scale": "x", "field": "data.x"},
"y": {"scale": "y", "field": "data.y"},
"stroke": {"value": style["color"]},
"strokeOpacity": {"value": style["alpha"]},
"strokeWidth": {"value": style["linewidth"]},
}
},
}
)
def draw_markers(self, data, coordinates, style, label, mplobj=None):
if coordinates != "data":
warnings.warn("Only data coordinates supported. Skipping this")
dataname = "table{0:03d}".format(len(self.data) + 1)
# TODO: respect the other style settings
self.data.append(
{"name": dataname, "values": [dict(x=d[0], y=d[1]) for d in data]}
)
self.marks.append(
{
"type": "symbol",
"from": {"data": dataname},
"properties": {
"enter": {
"interpolate": {"value": "monotone"},
"x": {"scale": "x", "field": "data.x"},
"y": {"scale": "y", "field": "data.y"},
"fill": {"value": style["facecolor"]},
"fillOpacity": {"value": style["alpha"]},
"stroke": {"value": style["edgecolor"]},
"strokeOpacity": {"value": style["alpha"]},
"strokeWidth": {"value": style["edgewidth"]},
}
},
}
)
def draw_text(
self, text, position, coordinates, style, text_type=None, mplobj=None
):
if text_type == "xlabel":
self.axes[0]["title"] = text
elif text_type == "ylabel":
self.axes[1]["title"] = text
class VegaHTML(object):
def __init__(self, renderer):
self.specification = dict(
width=renderer.figwidth,
height=renderer.figheight,
data=renderer.data,
scales=renderer.scales,
axes=renderer.axes,
marks=renderer.marks,
)
def html(self):
"""Build the HTML representation for IPython."""
id = random.randint(0, 2 ** 16)
html = '<div id="vis%d"></div>' % id
html += "<script>\n"
html += VEGA_TEMPLATE % (json.dumps(self.specification), id)
html += "</script>\n"
return html
def _repr_html_(self):
return self.html()
def fig_to_vega(fig, notebook=False):
"""Convert a matplotlib figure to vega dictionary
if notebook=True, then return an object which will display in a notebook
otherwise, return an HTML string.
"""
renderer = VegaRenderer()
Exporter(renderer).run(fig)
vega_html = VegaHTML(renderer)
if notebook:
return vega_html
else:
return vega_html.html()
VEGA_TEMPLATE = """
( function() {
var _do_plot = function() {
if ( (typeof vg == 'undefined') && (typeof IPython != 'undefined')) {
$([IPython.events]).on("vega_loaded.vincent", _do_plot);
return;
}
vg.parse.spec(%s, function(chart) {
chart({el: "#vis%d"}).update();
});
};
_do_plot();
})();
"""

View File

@ -0,0 +1,54 @@
import warnings
from .base import Renderer
from ..exporter import Exporter
class VincentRenderer(Renderer):
def open_figure(self, fig, props):
self.chart = None
self.figwidth = int(props["figwidth"] * props["dpi"])
self.figheight = int(props["figheight"] * props["dpi"])
def draw_line(self, data, coordinates, style, label, mplobj=None):
import vincent # only import if VincentRenderer is used
if coordinates != "data":
warnings.warn("Only data coordinates supported. Skipping this")
linedata = {"x": data[:, 0], "y": data[:, 1]}
line = vincent.Line(
linedata, iter_idx="x", width=self.figwidth, height=self.figheight
)
# TODO: respect the other style settings
line.scales["color"].range = [style["color"]]
if self.chart is None:
self.chart = line
else:
warnings.warn("Multiple plot elements not yet supported")
def draw_markers(self, data, coordinates, style, label, mplobj=None):
import vincent # only import if VincentRenderer is used
if coordinates != "data":
warnings.warn("Only data coordinates supported. Skipping this")
markerdata = {"x": data[:, 0], "y": data[:, 1]}
markers = vincent.Scatter(
markerdata, iter_idx="x", width=self.figwidth, height=self.figheight
)
# TODO: respect the other style settings
markers.scales["color"].range = [style["facecolor"]]
if self.chart is None:
self.chart = markers
else:
warnings.warn("Multiple plot elements not yet supported")
def fig_to_vincent(fig):
"""Convert a matplotlib figure to a vincent object"""
renderer = VincentRenderer()
exporter = Exporter(renderer)
exporter.run(fig)
return renderer.chart

View File

@ -0,0 +1,55 @@
"""
Tools for matplotlib plot exporting
"""
def ipynb_vega_init():
"""Initialize the IPython notebook display elements
This function borrows heavily from the excellent vincent package:
http://github.com/wrobstory/vincent
"""
try:
from IPython.core.display import display, HTML
except ImportError:
print("IPython Notebook could not be loaded.")
require_js = """
if (window['d3'] === undefined) {{
require.config({{ paths: {{d3: "http://d3js.org/d3.v3.min"}} }});
require(["d3"], function(d3) {{
window.d3 = d3;
{0}
}});
}};
if (window['topojson'] === undefined) {{
require.config(
{{ paths: {{topojson: "http://d3js.org/topojson.v1.min"}} }}
);
require(["topojson"], function(topojson) {{
window.topojson = topojson;
}});
}};
"""
d3_geo_projection_js_url = "http://d3js.org/d3.geo.projection.v0.min.js"
d3_layout_cloud_js_url = "http://wrobstory.github.io/d3-cloud/" "d3.layout.cloud.js"
topojson_js_url = "http://d3js.org/topojson.v1.min.js"
vega_js_url = "http://trifacta.github.com/vega/vega.js"
dep_libs = """$.getScript("%s", function() {
$.getScript("%s", function() {
$.getScript("%s", function() {
$.getScript("%s", function() {
$([IPython.events]).trigger("vega_loaded.vincent");
})
})
})
});""" % (
d3_geo_projection_js_url,
d3_layout_cloud_js_url,
topojson_js_url,
vega_js_url,
)
load_js = require_js.format(dep_libs)
html = "<script>" + load_js + "</script>"
display(HTML(html))

View File

@ -0,0 +1,382 @@
"""
Utility Routines for Working with Matplotlib Objects
====================================================
"""
import itertools
import io
import base64
import numpy as np
import warnings
import matplotlib
from matplotlib.colors import colorConverter
from matplotlib.path import Path
from matplotlib.markers import MarkerStyle
from matplotlib.transforms import Affine2D
from matplotlib import ticker
def export_color(color):
"""Convert matplotlib color code to hex color or RGBA color"""
if color is None or colorConverter.to_rgba(color)[3] == 0:
return "none"
elif colorConverter.to_rgba(color)[3] == 1:
rgb = colorConverter.to_rgb(color)
return "#{0:02X}{1:02X}{2:02X}".format(*(int(255 * c) for c in rgb))
else:
c = colorConverter.to_rgba(color)
return (
"rgba("
+ ", ".join(str(int(np.round(val * 255))) for val in c[:3])
+ ", "
+ str(c[3])
+ ")"
)
def _many_to_one(input_dict):
"""Convert a many-to-one mapping to a one-to-one mapping"""
return dict((key, val) for keys, val in input_dict.items() for key in keys)
LINESTYLES = _many_to_one(
{
("solid", "-", (None, None)): "none",
("dashed", "--"): "6,6",
("dotted", ":"): "2,2",
("dashdot", "-."): "4,4,2,4",
("", " ", "None", "none"): None,
}
)
def get_dasharray(obj):
"""Get an SVG dash array for the given matplotlib linestyle
Parameters
----------
obj : matplotlib object
The matplotlib line or path object, which must have a get_linestyle()
method which returns a valid matplotlib line code
Returns
-------
dasharray : string
The HTML/SVG dasharray code associated with the object.
"""
if obj.__dict__.get("_dashSeq", None) is not None:
return ",".join(map(str, obj._dashSeq))
else:
ls = obj.get_linestyle()
dasharray = LINESTYLES.get(ls, "not found")
if dasharray == "not found":
warnings.warn(
"line style '{0}' not understood: "
"defaulting to solid line.".format(ls)
)
dasharray = LINESTYLES["solid"]
return dasharray
PATH_DICT = {
Path.LINETO: "L",
Path.MOVETO: "M",
Path.CURVE3: "S",
Path.CURVE4: "C",
Path.CLOSEPOLY: "Z",
}
def SVG_path(path, transform=None, simplify=False):
"""Construct the vertices and SVG codes for the path
Parameters
----------
path : matplotlib.Path object
transform : matplotlib transform (optional)
if specified, the path will be transformed before computing the output.
Returns
-------
vertices : array
The shape (M, 2) array of vertices of the Path. Note that some Path
codes require multiple vertices, so the length of these vertices may
be longer than the list of path codes.
path_codes : list
A length N list of single-character path codes, N <= M. Each code is
a single character, in ['L','M','S','C','Z']. See the standard SVG
path specification for a description of these.
"""
if transform is not None:
path = path.transformed(transform)
vc_tuples = [
(vertices if path_code != Path.CLOSEPOLY else [], PATH_DICT[path_code])
for (vertices, path_code) in path.iter_segments(simplify=simplify)
]
if not vc_tuples:
# empty path is a special case
return np.zeros((0, 2)), []
else:
vertices, codes = zip(*vc_tuples)
vertices = np.array(list(itertools.chain(*vertices))).reshape(-1, 2)
return vertices, list(codes)
def get_path_style(path, fill=True):
"""Get the style dictionary for matplotlib path objects"""
style = {}
style["alpha"] = path.get_alpha()
if style["alpha"] is None:
style["alpha"] = 1
style["edgecolor"] = export_color(path.get_edgecolor())
if fill:
style["facecolor"] = export_color(path.get_facecolor())
else:
style["facecolor"] = "none"
style["edgewidth"] = path.get_linewidth()
style["dasharray"] = get_dasharray(path)
style["zorder"] = path.get_zorder()
return style
def get_line_style(line):
"""Get the style dictionary for matplotlib line objects"""
style = {}
style["alpha"] = line.get_alpha()
if style["alpha"] is None:
style["alpha"] = 1
style["color"] = export_color(line.get_color())
style["linewidth"] = line.get_linewidth()
style["dasharray"] = get_dasharray(line)
style["zorder"] = line.get_zorder()
style["drawstyle"] = line.get_drawstyle()
return style
def get_marker_style(line):
"""Get the style dictionary for matplotlib marker objects"""
style = {}
style["alpha"] = line.get_alpha()
if style["alpha"] is None:
style["alpha"] = 1
style["facecolor"] = export_color(line.get_markerfacecolor())
style["edgecolor"] = export_color(line.get_markeredgecolor())
style["edgewidth"] = line.get_markeredgewidth()
style["marker"] = line.get_marker()
markerstyle = MarkerStyle(line.get_marker())
markersize = line.get_markersize()
markertransform = markerstyle.get_transform() + Affine2D().scale(
markersize, -markersize
)
style["markerpath"] = SVG_path(markerstyle.get_path(), markertransform)
style["markersize"] = markersize
style["zorder"] = line.get_zorder()
return style
def get_text_style(text):
"""Return the text style dict for a text instance"""
style = {}
style["alpha"] = text.get_alpha()
if style["alpha"] is None:
style["alpha"] = 1
style["fontsize"] = text.get_size()
style["color"] = export_color(text.get_color())
style["halign"] = text.get_horizontalalignment() # left, center, right
style["valign"] = text.get_verticalalignment() # baseline, center, top
style["malign"] = text._multialignment # text alignment when '\n' in text
style["rotation"] = text.get_rotation()
style["zorder"] = text.get_zorder()
return style
def get_axis_properties(axis):
"""Return the property dictionary for a matplotlib.Axis instance"""
props = {}
label1On = axis._major_tick_kw.get("label1On", True)
if isinstance(axis, matplotlib.axis.XAxis):
if label1On:
props["position"] = "bottom"
else:
props["position"] = "top"
elif isinstance(axis, matplotlib.axis.YAxis):
if label1On:
props["position"] = "left"
else:
props["position"] = "right"
else:
raise ValueError("{0} should be an Axis instance".format(axis))
# Use tick values if appropriate
locator = axis.get_major_locator()
props["nticks"] = len(locator())
if isinstance(locator, ticker.FixedLocator):
props["tickvalues"] = list(locator())
else:
props["tickvalues"] = None
# Find tick formats
formatter = axis.get_major_formatter()
if isinstance(formatter, ticker.NullFormatter):
props["tickformat"] = ""
elif isinstance(formatter, ticker.FixedFormatter):
props["tickformat"] = list(formatter.seq)
elif isinstance(formatter, ticker.FuncFormatter):
props["tickformat"] = list(formatter.func.args[0].values())
elif not any(label.get_visible() for label in axis.get_ticklabels()):
props["tickformat"] = ""
else:
props["tickformat"] = None
# Get axis scale
props["scale"] = axis.get_scale()
# Get major tick label size (assumes that's all we really care about!)
labels = axis.get_ticklabels()
if labels:
props["fontsize"] = labels[0].get_fontsize()
else:
props["fontsize"] = None
# Get associated grid
props["grid"] = get_grid_style(axis)
# get axis visibility
props["visible"] = axis.get_visible()
return props
def get_grid_style(axis):
gridlines = axis.get_gridlines()
if axis._major_tick_kw["gridOn"] and len(gridlines) > 0:
color = export_color(gridlines[0].get_color())
alpha = gridlines[0].get_alpha()
dasharray = get_dasharray(gridlines[0])
return dict(gridOn=True, color=color, dasharray=dasharray, alpha=alpha)
else:
return {"gridOn": False}
def get_figure_properties(fig):
return {
"figwidth": fig.get_figwidth(),
"figheight": fig.get_figheight(),
"dpi": fig.dpi,
}
def get_axes_properties(ax):
props = {
"axesbg": export_color(ax.patch.get_facecolor()),
"axesbgalpha": ax.patch.get_alpha(),
"bounds": ax.get_position().bounds,
"dynamic": ax.get_navigate(),
"axison": ax.axison,
"frame_on": ax.get_frame_on(),
"patch_visible": ax.patch.get_visible(),
"axes": [get_axis_properties(ax.xaxis), get_axis_properties(ax.yaxis)],
}
for axname in ["x", "y"]:
axis = getattr(ax, axname + "axis")
domain = getattr(ax, "get_{0}lim".format(axname))()
lim = domain
if isinstance(axis.converter, matplotlib.dates.DateConverter):
scale = "date"
try:
import pandas as pd
from pandas.tseries.converter import PeriodConverter
except ImportError:
pd = None
if pd is not None and isinstance(axis.converter, PeriodConverter):
_dates = [pd.Period(ordinal=int(d), freq=axis.freq) for d in domain]
domain = [
(d.year, d.month - 1, d.day, d.hour, d.minute, d.second, 0)
for d in _dates
]
else:
domain = [
(
d.year,
d.month - 1,
d.day,
d.hour,
d.minute,
d.second,
d.microsecond * 1e-3,
)
for d in matplotlib.dates.num2date(domain)
]
else:
scale = axis.get_scale()
if scale not in ["date", "linear", "log"]:
raise ValueError("Unknown axis scale: " "{0}".format(axis.get_scale()))
props[axname + "scale"] = scale
props[axname + "lim"] = lim
props[axname + "domain"] = domain
return props
def iter_all_children(obj, skipContainers=False):
"""
Returns an iterator over all childen and nested children using
obj's get_children() method
if skipContainers is true, only childless objects are returned.
"""
if hasattr(obj, "get_children") and len(obj.get_children()) > 0:
for child in obj.get_children():
if not skipContainers:
yield child
# could use `yield from` in python 3...
for grandchild in iter_all_children(child, skipContainers):
yield grandchild
else:
yield obj
def get_legend_properties(ax, legend):
handles, labels = ax.get_legend_handles_labels()
visible = legend.get_visible()
return {"handles": handles, "labels": labels, "visible": visible}
def image_to_base64(image):
"""
Convert a matplotlib image to a base64 png representation
Parameters
----------
image : matplotlib image object
The image to be converted.
Returns
-------
image_base64 : string
The UTF8-encoded base64 string representation of the png image.
"""
ax = image.axes
binary_buffer = io.BytesIO()
# image is saved in axes coordinates: we need to temporarily
# set the correct limits to get the correct image
lim = ax.axis()
ax.axis(image.get_extent())
image.write_png(binary_buffer)
ax.axis(lim)
binary_buffer.seek(0)
return base64.b64encode(binary_buffer.read()).decode("utf-8")