first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
"""
matplotlylib
============
This module converts matplotlib figure objects into JSON structures which can
be understood and visualized by Plotly.
Most of the functionality should be accessed through the parent directory's
'tools' module or 'plotly' package.
"""
from __future__ import absolute_import
from plotly.matplotlylib.renderer import PlotlyRenderer
from plotly.matplotlylib.mplexporter import Exporter

View File

@@ -0,0 +1,2 @@
from .renderers import Renderer
from .exporter import Exporter

View File

@@ -0,0 +1,25 @@
"""
Simple fixes for Python 2/3 compatibility
"""
import sys
PY3K = sys.version_info[0] >= 3
if PY3K:
import builtins
import functools
reduce = functools.reduce
zip = builtins.zip
xrange = builtins.range
map = builtins.map
else:
import __builtin__
import itertools
builtins = __builtin__
reduce = __builtin__.reduce
zip = itertools.izip
xrange = __builtin__.xrange
map = itertools.imap

View File

@@ -0,0 +1,310 @@
"""
Matplotlib Exporter
===================
This submodule contains tools for crawling a matplotlib figure and exporting
relevant pieces to a renderer.
"""
import warnings
import io
from . import utils
import matplotlib
from matplotlib import transforms, collections
from matplotlib.backends.backend_agg import FigureCanvasAgg
class Exporter(object):
"""Matplotlib Exporter
Parameters
----------
renderer : Renderer object
The renderer object called by the exporter to create a figure
visualization. See mplexporter.Renderer for information on the
methods which should be defined within the renderer.
close_mpl : bool
If True (default), close the matplotlib figure as it is rendered. This
is useful for when the exporter is used within the notebook, or with
an interactive matplotlib backend.
"""
def __init__(self, renderer, close_mpl=True):
self.close_mpl = close_mpl
self.renderer = renderer
def run(self, fig):
"""
Run the exporter on the given figure
Parmeters
---------
fig : matplotlib.Figure instance
The figure to export
"""
# Calling savefig executes the draw() command, putting elements
# in the correct place.
if fig.canvas is None:
canvas = FigureCanvasAgg(fig)
fig.savefig(io.BytesIO(), format="png", dpi=fig.dpi)
if self.close_mpl:
import matplotlib.pyplot as plt
plt.close(fig)
self.crawl_fig(fig)
@staticmethod
def process_transform(
transform, ax=None, data=None, return_trans=False, force_trans=None
):
"""Process the transform and convert data to figure or data coordinates
Parameters
----------
transform : matplotlib Transform object
The transform applied to the data
ax : matplotlib Axes object (optional)
The axes the data is associated with
data : ndarray (optional)
The array of data to be transformed.
return_trans : bool (optional)
If true, return the final transform of the data
force_trans : matplotlib.transform instance (optional)
If supplied, first force the data to this transform
Returns
-------
code : string
Code is either "data", "axes", "figure", or "display", indicating
the type of coordinates output.
transform : matplotlib transform
the transform used to map input data to output data.
Returned only if return_trans is True
new_data : ndarray
Data transformed to match the given coordinate code.
Returned only if data is specified
"""
if isinstance(transform, transforms.BlendedGenericTransform):
warnings.warn(
"Blended transforms not yet supported. "
"Zoom behavior may not work as expected."
)
if force_trans is not None:
if data is not None:
data = (transform - force_trans).transform(data)
transform = force_trans
code = "display"
if ax is not None:
for (c, trans) in [
("data", ax.transData),
("axes", ax.transAxes),
("figure", ax.figure.transFigure),
("display", transforms.IdentityTransform()),
]:
if transform.contains_branch(trans):
code, transform = (c, transform - trans)
break
if data is not None:
if return_trans:
return code, transform.transform(data), transform
else:
return code, transform.transform(data)
else:
if return_trans:
return code, transform
else:
return code
def crawl_fig(self, fig):
"""Crawl the figure and process all axes"""
with self.renderer.draw_figure(fig=fig, props=utils.get_figure_properties(fig)):
for ax in fig.axes:
self.crawl_ax(ax)
def crawl_ax(self, ax):
"""Crawl the axes and process all elements within"""
with self.renderer.draw_axes(ax=ax, props=utils.get_axes_properties(ax)):
for line in ax.lines:
self.draw_line(ax, line)
for text in ax.texts:
self.draw_text(ax, text)
for (text, ttp) in zip(
[ax.xaxis.label, ax.yaxis.label, ax.title],
["xlabel", "ylabel", "title"],
):
if hasattr(text, "get_text") and text.get_text():
self.draw_text(ax, text, force_trans=ax.transAxes, text_type=ttp)
for artist in ax.artists:
# TODO: process other artists
if isinstance(artist, matplotlib.text.Text):
self.draw_text(ax, artist)
for patch in ax.patches:
self.draw_patch(ax, patch)
for collection in ax.collections:
self.draw_collection(ax, collection)
for image in ax.images:
self.draw_image(ax, image)
legend = ax.get_legend()
if legend is not None:
props = utils.get_legend_properties(ax, legend)
with self.renderer.draw_legend(legend=legend, props=props):
if props["visible"]:
self.crawl_legend(ax, legend)
def crawl_legend(self, ax, legend):
"""
Recursively look through objects in legend children
"""
legendElements = list(
utils.iter_all_children(legend._legend_box, skipContainers=True)
)
legendElements.append(legend.legendPatch)
for child in legendElements:
# force a large zorder so it appears on top
child.set_zorder(1e6 + child.get_zorder())
# reorder border box to make sure marks are visible
if isinstance(child, matplotlib.patches.FancyBboxPatch):
child.set_zorder(child.get_zorder() - 1)
try:
# What kind of object...
if isinstance(child, matplotlib.patches.Patch):
self.draw_patch(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.text.Text):
if child.get_text() != "None":
self.draw_text(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.lines.Line2D):
self.draw_line(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.collections.Collection):
self.draw_collection(ax, child, force_pathtrans=ax.transAxes)
else:
warnings.warn("Legend element %s not impemented" % child)
except NotImplementedError:
warnings.warn("Legend element %s not impemented" % child)
def draw_line(self, ax, line, force_trans=None):
"""Process a matplotlib line and call renderer.draw_line"""
coordinates, data = self.process_transform(
line.get_transform(), ax, line.get_xydata(), force_trans=force_trans
)
linestyle = utils.get_line_style(line)
if linestyle["dasharray"] is None and linestyle["drawstyle"] == "default":
linestyle = None
markerstyle = utils.get_marker_style(line)
if (
markerstyle["marker"] in ["None", "none", None]
or markerstyle["markerpath"][0].size == 0
):
markerstyle = None
label = line.get_label()
if markerstyle or linestyle:
self.renderer.draw_marked_line(
data=data,
coordinates=coordinates,
linestyle=linestyle,
markerstyle=markerstyle,
label=label,
mplobj=line,
)
def draw_text(self, ax, text, force_trans=None, text_type=None):
"""Process a matplotlib text object and call renderer.draw_text"""
content = text.get_text()
if content:
transform = text.get_transform()
position = text.get_position()
coords, position = self.process_transform(
transform, ax, position, force_trans=force_trans
)
style = utils.get_text_style(text)
self.renderer.draw_text(
text=content,
position=position,
coordinates=coords,
text_type=text_type,
style=style,
mplobj=text,
)
def draw_patch(self, ax, patch, force_trans=None):
"""Process a matplotlib patch object and call renderer.draw_path"""
vertices, pathcodes = utils.SVG_path(patch.get_path())
transform = patch.get_transform()
coordinates, vertices = self.process_transform(
transform, ax, vertices, force_trans=force_trans
)
linestyle = utils.get_path_style(patch, fill=patch.get_fill())
self.renderer.draw_path(
data=vertices,
coordinates=coordinates,
pathcodes=pathcodes,
style=linestyle,
mplobj=patch,
)
def draw_collection(
self, ax, collection, force_pathtrans=None, force_offsettrans=None
):
"""Process a matplotlib collection and call renderer.draw_collection"""
(transform, transOffset, offsets, paths) = collection._prepare_points()
offset_coords, offsets = self.process_transform(
transOffset, ax, offsets, force_trans=force_offsettrans
)
path_coords = self.process_transform(transform, ax, force_trans=force_pathtrans)
processed_paths = [utils.SVG_path(path) for path in paths]
processed_paths = [
(
self.process_transform(
transform, ax, path[0], force_trans=force_pathtrans
)[1],
path[1],
)
for path in processed_paths
]
path_transforms = collection.get_transforms()
try:
# matplotlib 1.3: path_transforms are transform objects.
# Convert them to numpy arrays.
path_transforms = [t.get_matrix() for t in path_transforms]
except AttributeError:
# matplotlib 1.4: path transforms are already numpy arrays.
pass
styles = {
"linewidth": collection.get_linewidths(),
"facecolor": collection.get_facecolors(),
"edgecolor": collection.get_edgecolors(),
"alpha": collection._alpha,
"zorder": collection.get_zorder(),
}
offset_dict = {"data": "before", "screen": "after"}
offset_order = offset_dict[collection.get_offset_position()]
self.renderer.draw_path_collection(
paths=processed_paths,
path_coordinates=path_coords,
path_transforms=path_transforms,
offsets=offsets,
offset_coordinates=offset_coords,
offset_order=offset_order,
styles=styles,
mplobj=collection,
)
def draw_image(self, ax, image):
"""Process a matplotlib image object and call renderer.draw_image"""
self.renderer.draw_image(
imdata=utils.image_to_base64(image),
extent=image.get_extent(),
coordinates="data",
style={"alpha": image.get_alpha(), "zorder": image.get_zorder()},
mplobj=image,
)

View File

@@ -0,0 +1,12 @@
"""
Matplotlib Renderers
====================
This submodule contains renderer objects which define renderer behavior used
within the Exporter class. The base renderer class is :class:`Renderer`, an
abstract base class
"""
from .base import Renderer
from .vega_renderer import VegaRenderer, fig_to_vega
from .vincent_renderer import VincentRenderer, fig_to_vincent
from .fake_renderer import FakeRenderer, FullFakeRenderer

View File

@@ -0,0 +1,429 @@
import warnings
import itertools
from contextlib import contextmanager
from distutils.version import LooseVersion
import numpy as np
import matplotlib as mpl
from matplotlib import transforms
from .. import utils
from .. import _py3k_compat as py3k
class Renderer(object):
@staticmethod
def ax_zoomable(ax):
return bool(ax and ax.get_navigate())
@staticmethod
def ax_has_xgrid(ax):
return bool(ax and ax.xaxis._gridOnMajor and ax.yaxis.get_gridlines())
@staticmethod
def ax_has_ygrid(ax):
return bool(ax and ax.yaxis._gridOnMajor and ax.yaxis.get_gridlines())
@property
def current_ax_zoomable(self):
return self.ax_zoomable(self._current_ax)
@property
def current_ax_has_xgrid(self):
return self.ax_has_xgrid(self._current_ax)
@property
def current_ax_has_ygrid(self):
return self.ax_has_ygrid(self._current_ax)
@contextmanager
def draw_figure(self, fig, props):
if hasattr(self, "_current_fig") and self._current_fig is not None:
warnings.warn("figure embedded in figure: something is wrong")
self._current_fig = fig
self._fig_props = props
self.open_figure(fig=fig, props=props)
yield
self.close_figure(fig=fig)
self._current_fig = None
self._fig_props = {}
@contextmanager
def draw_axes(self, ax, props):
if hasattr(self, "_current_ax") and self._current_ax is not None:
warnings.warn("axes embedded in axes: something is wrong")
self._current_ax = ax
self._ax_props = props
self.open_axes(ax=ax, props=props)
yield
self.close_axes(ax=ax)
self._current_ax = None
self._ax_props = {}
@contextmanager
def draw_legend(self, legend, props):
self._current_legend = legend
self._legend_props = props
self.open_legend(legend=legend, props=props)
yield
self.close_legend(legend=legend)
self._current_legend = None
self._legend_props = {}
# Following are the functions which should be overloaded in subclasses
def open_figure(self, fig, props):
"""
Begin commands for a particular figure.
Parameters
----------
fig : matplotlib.Figure
The Figure which will contain the ensuing axes and elements
props : dictionary
The dictionary of figure properties
"""
pass
def close_figure(self, fig):
"""
Finish commands for a particular figure.
Parameters
----------
fig : matplotlib.Figure
The figure which is finished being drawn.
"""
pass
def open_axes(self, ax, props):
"""
Begin commands for a particular axes.
Parameters
----------
ax : matplotlib.Axes
The Axes which will contain the ensuing axes and elements
props : dictionary
The dictionary of axes properties
"""
pass
def close_axes(self, ax):
"""
Finish commands for a particular axes.
Parameters
----------
ax : matplotlib.Axes
The Axes which is finished being drawn.
"""
pass
def open_legend(self, legend, props):
"""
Beging commands for a particular legend.
Parameters
----------
legend : matplotlib.legend.Legend
The Legend that will contain the ensuing elements
props : dictionary
The dictionary of legend properties
"""
pass
def close_legend(self, legend):
"""
Finish commands for a particular legend.
Parameters
----------
legend : matplotlib.legend.Legend
The Legend which is finished being drawn
"""
pass
def draw_marked_line(
self, data, coordinates, linestyle, markerstyle, label, mplobj=None
):
"""Draw a line that also has markers.
If this isn't reimplemented by a renderer object, by default, it will
make a call to BOTH draw_line and draw_markers when both markerstyle
and linestyle are not None in the same Line2D object.
"""
if linestyle is not None:
self.draw_line(data, coordinates, linestyle, label, mplobj)
if markerstyle is not None:
self.draw_markers(data, coordinates, markerstyle, label, mplobj)
def draw_line(self, data, coordinates, style, label, mplobj=None):
"""
Draw a line. By default, draw the line via the draw_path() command.
Some renderers might wish to override this and provide more
fine-grained behavior.
In matplotlib, lines are generally created via the plt.plot() command,
though this command also can create marker collections.
Parameters
----------
data : array_like
A shape (N, 2) array of datapoints.
coordinates : string
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
style : dictionary
a dictionary specifying the appearance of the line.
mplobj : matplotlib object
the matplotlib plot element which generated this line
"""
pathcodes = ["M"] + (data.shape[0] - 1) * ["L"]
pathstyle = dict(facecolor="none", **style)
pathstyle["edgecolor"] = pathstyle.pop("color")
pathstyle["edgewidth"] = pathstyle.pop("linewidth")
self.draw_path(
data=data,
coordinates=coordinates,
pathcodes=pathcodes,
style=pathstyle,
mplobj=mplobj,
)
@staticmethod
def _iter_path_collection(paths, path_transforms, offsets, styles):
"""Build an iterator over the elements of the path collection"""
N = max(len(paths), len(offsets))
# Before mpl 1.4.0, path_transform can be a false-y value, not a valid
# transformation matrix.
if LooseVersion(mpl.__version__) < LooseVersion("1.4.0"):
if path_transforms is None:
path_transforms = [np.eye(3)]
edgecolor = styles["edgecolor"]
if np.size(edgecolor) == 0:
edgecolor = ["none"]
facecolor = styles["facecolor"]
if np.size(facecolor) == 0:
facecolor = ["none"]
elements = [
paths,
path_transforms,
offsets,
edgecolor,
styles["linewidth"],
facecolor,
]
it = itertools
return it.islice(py3k.zip(*py3k.map(it.cycle, elements)), N)
def draw_path_collection(
self,
paths,
path_coordinates,
path_transforms,
offsets,
offset_coordinates,
offset_order,
styles,
mplobj=None,
):
"""
Draw a collection of paths. The paths, offsets, and styles are all
iterables, and the number of paths is max(len(paths), len(offsets)).
By default, this is implemented via multiple calls to the draw_path()
function. For efficiency, Renderers may choose to customize this
implementation.
Examples of path collections created by matplotlib are scatter plots,
histograms, contour plots, and many others.
Parameters
----------
paths : list
list of tuples, where each tuple has two elements:
(data, pathcodes). See draw_path() for a description of these.
path_coordinates: string
the coordinates code for the paths, which should be either
'data' for data coordinates, or 'figure' for figure (pixel)
coordinates.
path_transforms: array_like
an array of shape (*, 3, 3), giving a series of 2D Affine
transforms for the paths. These encode translations, rotations,
and scalings in the standard way.
offsets: array_like
An array of offsets of shape (N, 2)
offset_coordinates : string
the coordinates code for the offsets, which should be either
'data' for data coordinates, or 'figure' for figure (pixel)
coordinates.
offset_order : string
either "before" or "after". This specifies whether the offset
is applied before the path transform, or after. The matplotlib
backend equivalent is "before"->"data", "after"->"screen".
styles: dictionary
A dictionary in which each value is a list of length N, containing
the style(s) for the paths.
mplobj : matplotlib object
the matplotlib plot element which generated this collection
"""
if offset_order == "before":
raise NotImplementedError("offset before transform")
for tup in self._iter_path_collection(paths, path_transforms, offsets, styles):
(path, path_transform, offset, ec, lw, fc) = tup
vertices, pathcodes = path
path_transform = transforms.Affine2D(path_transform)
vertices = path_transform.transform(vertices)
# This is a hack:
if path_coordinates == "figure":
path_coordinates = "points"
style = {
"edgecolor": utils.export_color(ec),
"facecolor": utils.export_color(fc),
"edgewidth": lw,
"dasharray": "10,0",
"alpha": styles["alpha"],
"zorder": styles["zorder"],
}
self.draw_path(
data=vertices,
coordinates=path_coordinates,
pathcodes=pathcodes,
style=style,
offset=offset,
offset_coordinates=offset_coordinates,
mplobj=mplobj,
)
def draw_markers(self, data, coordinates, style, label, mplobj=None):
"""
Draw a set of markers. By default, this is done by repeatedly
calling draw_path(), but renderers should generally overload
this method to provide a more efficient implementation.
In matplotlib, markers are created using the plt.plot() command.
Parameters
----------
data : array_like
A shape (N, 2) array of datapoints.
coordinates : string
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
style : dictionary
a dictionary specifying the appearance of the markers.
mplobj : matplotlib object
the matplotlib plot element which generated this marker collection
"""
vertices, pathcodes = style["markerpath"]
pathstyle = dict(
(key, style[key])
for key in ["alpha", "edgecolor", "facecolor", "zorder", "edgewidth"]
)
pathstyle["dasharray"] = "10,0"
for vertex in data:
self.draw_path(
data=vertices,
coordinates="points",
pathcodes=pathcodes,
style=pathstyle,
offset=vertex,
offset_coordinates=coordinates,
mplobj=mplobj,
)
def draw_text(
self, text, position, coordinates, style, text_type=None, mplobj=None
):
"""
Draw text on the image.
Parameters
----------
text : string
The text to draw
position : tuple
The (x, y) position of the text
coordinates : string
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
style : dictionary
a dictionary specifying the appearance of the text.
text_type : string or None
if specified, a type of text such as "xlabel", "ylabel", "title"
mplobj : matplotlib object
the matplotlib plot element which generated this text
"""
raise NotImplementedError()
def draw_path(
self,
data,
coordinates,
pathcodes,
style,
offset=None,
offset_coordinates="data",
mplobj=None,
):
"""
Draw a path.
In matplotlib, paths are created by filled regions, histograms,
contour plots, patches, etc.
Parameters
----------
data : array_like
A shape (N, 2) array of datapoints.
coordinates : string
A string code, which should be either 'data' for data coordinates,
'figure' for figure (pixel) coordinates, or "points" for raw
point coordinates (useful in conjunction with offsets, below).
pathcodes : list
A list of single-character SVG pathcodes associated with the data.
Path codes are one of ['M', 'm', 'L', 'l', 'Q', 'q', 'T', 't',
'S', 's', 'C', 'c', 'Z', 'z']
See the SVG specification for details. Note that some path codes
consume more than one datapoint (while 'Z' consumes none), so
in general, the length of the pathcodes list will not be the same
as that of the data array.
style : dictionary
a dictionary specifying the appearance of the line.
offset : list (optional)
the (x, y) offset of the path. If not given, no offset will
be used.
offset_coordinates : string (optional)
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
mplobj : matplotlib object
the matplotlib plot element which generated this path
"""
raise NotImplementedError()
def draw_image(self, imdata, extent, coordinates, style, mplobj=None):
"""
Draw an image.
Parameters
----------
imdata : string
base64 encoded png representation of the image
extent : list
the axes extent of the image: [xmin, xmax, ymin, ymax]
coordinates: string
A string code, which should be either 'data' for data coordinates,
or 'figure' for figure (pixel) coordinates.
style : dictionary
a dictionary specifying the appearance of the image
mplobj : matplotlib object
the matplotlib plot object which generated this image
"""
raise NotImplementedError()

View File

@@ -0,0 +1,88 @@
from .base import Renderer
class FakeRenderer(Renderer):
"""
Fake Renderer
This is a fake renderer which simply outputs a text tree representing the
elements found in the plot(s). This is used in the unit tests for the
package.
Below are the methods your renderer must implement. You are free to do
anything you wish within the renderer (i.e. build an XML or JSON
representation, call an external API, etc.) Here the renderer just
builds a simple string representation for testing purposes.
"""
def __init__(self):
self.output = ""
def open_figure(self, fig, props):
self.output += "opening figure\n"
def close_figure(self, fig):
self.output += "closing figure\n"
def open_axes(self, ax, props):
self.output += " opening axes\n"
def close_axes(self, ax):
self.output += " closing axes\n"
def open_legend(self, legend, props):
self.output += " opening legend\n"
def close_legend(self, legend):
self.output += " closing legend\n"
def draw_text(
self, text, position, coordinates, style, text_type=None, mplobj=None
):
self.output += " draw text '{0}' {1}\n".format(text, text_type)
def draw_path(
self,
data,
coordinates,
pathcodes,
style,
offset=None,
offset_coordinates="data",
mplobj=None,
):
self.output += " draw path with {0} vertices\n".format(data.shape[0])
def draw_image(self, imdata, extent, coordinates, style, mplobj=None):
self.output += " draw image of size {0}\n".format(len(imdata))
class FullFakeRenderer(FakeRenderer):
"""
Renderer with the full complement of methods.
When the following are left undefined, they will be implemented via
other methods in the class. They can be defined explicitly for
more efficient or specialized use within the renderer implementation.
"""
def draw_line(self, data, coordinates, style, label, mplobj=None):
self.output += " draw line with {0} points\n".format(data.shape[0])
def draw_markers(self, data, coordinates, style, label, mplobj=None):
self.output += " draw {0} markers\n".format(data.shape[0])
def draw_path_collection(
self,
paths,
path_coordinates,
path_transforms,
offsets,
offset_coordinates,
offset_order,
styles,
mplobj=None,
):
self.output += " draw path collection " "with {0} offsets\n".format(
offsets.shape[0]
)

View File

@@ -0,0 +1,145 @@
import warnings
import json
import random
from .base import Renderer
from ..exporter import Exporter
class VegaRenderer(Renderer):
def open_figure(self, fig, props):
self.props = props
self.figwidth = int(props["figwidth"] * props["dpi"])
self.figheight = int(props["figheight"] * props["dpi"])
self.data = []
self.scales = []
self.axes = []
self.marks = []
def open_axes(self, ax, props):
if len(self.axes) > 0:
warnings.warn("multiple axes not yet supported")
self.axes = [
dict(type="x", scale="x", ticks=10),
dict(type="y", scale="y", ticks=10),
]
self.scales = [
dict(name="x", domain=props["xlim"], type="linear", range="width",),
dict(name="y", domain=props["ylim"], type="linear", range="height",),
]
def draw_line(self, data, coordinates, style, label, mplobj=None):
if coordinates != "data":
warnings.warn("Only data coordinates supported. Skipping this")
dataname = "table{0:03d}".format(len(self.data) + 1)
# TODO: respect the other style settings
self.data.append(
{"name": dataname, "values": [dict(x=d[0], y=d[1]) for d in data]}
)
self.marks.append(
{
"type": "line",
"from": {"data": dataname},
"properties": {
"enter": {
"interpolate": {"value": "monotone"},
"x": {"scale": "x", "field": "data.x"},
"y": {"scale": "y", "field": "data.y"},
"stroke": {"value": style["color"]},
"strokeOpacity": {"value": style["alpha"]},
"strokeWidth": {"value": style["linewidth"]},
}
},
}
)
def draw_markers(self, data, coordinates, style, label, mplobj=None):
if coordinates != "data":
warnings.warn("Only data coordinates supported. Skipping this")
dataname = "table{0:03d}".format(len(self.data) + 1)
# TODO: respect the other style settings
self.data.append(
{"name": dataname, "values": [dict(x=d[0], y=d[1]) for d in data]}
)
self.marks.append(
{
"type": "symbol",
"from": {"data": dataname},
"properties": {
"enter": {
"interpolate": {"value": "monotone"},
"x": {"scale": "x", "field": "data.x"},
"y": {"scale": "y", "field": "data.y"},
"fill": {"value": style["facecolor"]},
"fillOpacity": {"value": style["alpha"]},
"stroke": {"value": style["edgecolor"]},
"strokeOpacity": {"value": style["alpha"]},
"strokeWidth": {"value": style["edgewidth"]},
}
},
}
)
def draw_text(
self, text, position, coordinates, style, text_type=None, mplobj=None
):
if text_type == "xlabel":
self.axes[0]["title"] = text
elif text_type == "ylabel":
self.axes[1]["title"] = text
class VegaHTML(object):
def __init__(self, renderer):
self.specification = dict(
width=renderer.figwidth,
height=renderer.figheight,
data=renderer.data,
scales=renderer.scales,
axes=renderer.axes,
marks=renderer.marks,
)
def html(self):
"""Build the HTML representation for IPython."""
id = random.randint(0, 2 ** 16)
html = '<div id="vis%d"></div>' % id
html += "<script>\n"
html += VEGA_TEMPLATE % (json.dumps(self.specification), id)
html += "</script>\n"
return html
def _repr_html_(self):
return self.html()
def fig_to_vega(fig, notebook=False):
"""Convert a matplotlib figure to vega dictionary
if notebook=True, then return an object which will display in a notebook
otherwise, return an HTML string.
"""
renderer = VegaRenderer()
Exporter(renderer).run(fig)
vega_html = VegaHTML(renderer)
if notebook:
return vega_html
else:
return vega_html.html()
VEGA_TEMPLATE = """
( function() {
var _do_plot = function() {
if ( (typeof vg == 'undefined') && (typeof IPython != 'undefined')) {
$([IPython.events]).on("vega_loaded.vincent", _do_plot);
return;
}
vg.parse.spec(%s, function(chart) {
chart({el: "#vis%d"}).update();
});
};
_do_plot();
})();
"""

View File

@@ -0,0 +1,54 @@
import warnings
from .base import Renderer
from ..exporter import Exporter
class VincentRenderer(Renderer):
def open_figure(self, fig, props):
self.chart = None
self.figwidth = int(props["figwidth"] * props["dpi"])
self.figheight = int(props["figheight"] * props["dpi"])
def draw_line(self, data, coordinates, style, label, mplobj=None):
import vincent # only import if VincentRenderer is used
if coordinates != "data":
warnings.warn("Only data coordinates supported. Skipping this")
linedata = {"x": data[:, 0], "y": data[:, 1]}
line = vincent.Line(
linedata, iter_idx="x", width=self.figwidth, height=self.figheight
)
# TODO: respect the other style settings
line.scales["color"].range = [style["color"]]
if self.chart is None:
self.chart = line
else:
warnings.warn("Multiple plot elements not yet supported")
def draw_markers(self, data, coordinates, style, label, mplobj=None):
import vincent # only import if VincentRenderer is used
if coordinates != "data":
warnings.warn("Only data coordinates supported. Skipping this")
markerdata = {"x": data[:, 0], "y": data[:, 1]}
markers = vincent.Scatter(
markerdata, iter_idx="x", width=self.figwidth, height=self.figheight
)
# TODO: respect the other style settings
markers.scales["color"].range = [style["facecolor"]]
if self.chart is None:
self.chart = markers
else:
warnings.warn("Multiple plot elements not yet supported")
def fig_to_vincent(fig):
"""Convert a matplotlib figure to a vincent object"""
renderer = VincentRenderer()
exporter = Exporter(renderer)
exporter.run(fig)
return renderer.chart

View File

@@ -0,0 +1,55 @@
"""
Tools for matplotlib plot exporting
"""
def ipynb_vega_init():
"""Initialize the IPython notebook display elements
This function borrows heavily from the excellent vincent package:
http://github.com/wrobstory/vincent
"""
try:
from IPython.core.display import display, HTML
except ImportError:
print("IPython Notebook could not be loaded.")
require_js = """
if (window['d3'] === undefined) {{
require.config({{ paths: {{d3: "http://d3js.org/d3.v3.min"}} }});
require(["d3"], function(d3) {{
window.d3 = d3;
{0}
}});
}};
if (window['topojson'] === undefined) {{
require.config(
{{ paths: {{topojson: "http://d3js.org/topojson.v1.min"}} }}
);
require(["topojson"], function(topojson) {{
window.topojson = topojson;
}});
}};
"""
d3_geo_projection_js_url = "http://d3js.org/d3.geo.projection.v0.min.js"
d3_layout_cloud_js_url = "http://wrobstory.github.io/d3-cloud/" "d3.layout.cloud.js"
topojson_js_url = "http://d3js.org/topojson.v1.min.js"
vega_js_url = "http://trifacta.github.com/vega/vega.js"
dep_libs = """$.getScript("%s", function() {
$.getScript("%s", function() {
$.getScript("%s", function() {
$.getScript("%s", function() {
$([IPython.events]).trigger("vega_loaded.vincent");
})
})
})
});""" % (
d3_geo_projection_js_url,
d3_layout_cloud_js_url,
topojson_js_url,
vega_js_url,
)
load_js = require_js.format(dep_libs)
html = "<script>" + load_js + "</script>"
display(HTML(html))

View File

@@ -0,0 +1,382 @@
"""
Utility Routines for Working with Matplotlib Objects
====================================================
"""
import itertools
import io
import base64
import numpy as np
import warnings
import matplotlib
from matplotlib.colors import colorConverter
from matplotlib.path import Path
from matplotlib.markers import MarkerStyle
from matplotlib.transforms import Affine2D
from matplotlib import ticker
def export_color(color):
"""Convert matplotlib color code to hex color or RGBA color"""
if color is None or colorConverter.to_rgba(color)[3] == 0:
return "none"
elif colorConverter.to_rgba(color)[3] == 1:
rgb = colorConverter.to_rgb(color)
return "#{0:02X}{1:02X}{2:02X}".format(*(int(255 * c) for c in rgb))
else:
c = colorConverter.to_rgba(color)
return (
"rgba("
+ ", ".join(str(int(np.round(val * 255))) for val in c[:3])
+ ", "
+ str(c[3])
+ ")"
)
def _many_to_one(input_dict):
"""Convert a many-to-one mapping to a one-to-one mapping"""
return dict((key, val) for keys, val in input_dict.items() for key in keys)
LINESTYLES = _many_to_one(
{
("solid", "-", (None, None)): "none",
("dashed", "--"): "6,6",
("dotted", ":"): "2,2",
("dashdot", "-."): "4,4,2,4",
("", " ", "None", "none"): None,
}
)
def get_dasharray(obj):
"""Get an SVG dash array for the given matplotlib linestyle
Parameters
----------
obj : matplotlib object
The matplotlib line or path object, which must have a get_linestyle()
method which returns a valid matplotlib line code
Returns
-------
dasharray : string
The HTML/SVG dasharray code associated with the object.
"""
if obj.__dict__.get("_dashSeq", None) is not None:
return ",".join(map(str, obj._dashSeq))
else:
ls = obj.get_linestyle()
dasharray = LINESTYLES.get(ls, "not found")
if dasharray == "not found":
warnings.warn(
"line style '{0}' not understood: "
"defaulting to solid line.".format(ls)
)
dasharray = LINESTYLES["solid"]
return dasharray
PATH_DICT = {
Path.LINETO: "L",
Path.MOVETO: "M",
Path.CURVE3: "S",
Path.CURVE4: "C",
Path.CLOSEPOLY: "Z",
}
def SVG_path(path, transform=None, simplify=False):
"""Construct the vertices and SVG codes for the path
Parameters
----------
path : matplotlib.Path object
transform : matplotlib transform (optional)
if specified, the path will be transformed before computing the output.
Returns
-------
vertices : array
The shape (M, 2) array of vertices of the Path. Note that some Path
codes require multiple vertices, so the length of these vertices may
be longer than the list of path codes.
path_codes : list
A length N list of single-character path codes, N <= M. Each code is
a single character, in ['L','M','S','C','Z']. See the standard SVG
path specification for a description of these.
"""
if transform is not None:
path = path.transformed(transform)
vc_tuples = [
(vertices if path_code != Path.CLOSEPOLY else [], PATH_DICT[path_code])
for (vertices, path_code) in path.iter_segments(simplify=simplify)
]
if not vc_tuples:
# empty path is a special case
return np.zeros((0, 2)), []
else:
vertices, codes = zip(*vc_tuples)
vertices = np.array(list(itertools.chain(*vertices))).reshape(-1, 2)
return vertices, list(codes)
def get_path_style(path, fill=True):
"""Get the style dictionary for matplotlib path objects"""
style = {}
style["alpha"] = path.get_alpha()
if style["alpha"] is None:
style["alpha"] = 1
style["edgecolor"] = export_color(path.get_edgecolor())
if fill:
style["facecolor"] = export_color(path.get_facecolor())
else:
style["facecolor"] = "none"
style["edgewidth"] = path.get_linewidth()
style["dasharray"] = get_dasharray(path)
style["zorder"] = path.get_zorder()
return style
def get_line_style(line):
"""Get the style dictionary for matplotlib line objects"""
style = {}
style["alpha"] = line.get_alpha()
if style["alpha"] is None:
style["alpha"] = 1
style["color"] = export_color(line.get_color())
style["linewidth"] = line.get_linewidth()
style["dasharray"] = get_dasharray(line)
style["zorder"] = line.get_zorder()
style["drawstyle"] = line.get_drawstyle()
return style
def get_marker_style(line):
"""Get the style dictionary for matplotlib marker objects"""
style = {}
style["alpha"] = line.get_alpha()
if style["alpha"] is None:
style["alpha"] = 1
style["facecolor"] = export_color(line.get_markerfacecolor())
style["edgecolor"] = export_color(line.get_markeredgecolor())
style["edgewidth"] = line.get_markeredgewidth()
style["marker"] = line.get_marker()
markerstyle = MarkerStyle(line.get_marker())
markersize = line.get_markersize()
markertransform = markerstyle.get_transform() + Affine2D().scale(
markersize, -markersize
)
style["markerpath"] = SVG_path(markerstyle.get_path(), markertransform)
style["markersize"] = markersize
style["zorder"] = line.get_zorder()
return style
def get_text_style(text):
"""Return the text style dict for a text instance"""
style = {}
style["alpha"] = text.get_alpha()
if style["alpha"] is None:
style["alpha"] = 1
style["fontsize"] = text.get_size()
style["color"] = export_color(text.get_color())
style["halign"] = text.get_horizontalalignment() # left, center, right
style["valign"] = text.get_verticalalignment() # baseline, center, top
style["malign"] = text._multialignment # text alignment when '\n' in text
style["rotation"] = text.get_rotation()
style["zorder"] = text.get_zorder()
return style
def get_axis_properties(axis):
"""Return the property dictionary for a matplotlib.Axis instance"""
props = {}
label1On = axis._major_tick_kw.get("label1On", True)
if isinstance(axis, matplotlib.axis.XAxis):
if label1On:
props["position"] = "bottom"
else:
props["position"] = "top"
elif isinstance(axis, matplotlib.axis.YAxis):
if label1On:
props["position"] = "left"
else:
props["position"] = "right"
else:
raise ValueError("{0} should be an Axis instance".format(axis))
# Use tick values if appropriate
locator = axis.get_major_locator()
props["nticks"] = len(locator())
if isinstance(locator, ticker.FixedLocator):
props["tickvalues"] = list(locator())
else:
props["tickvalues"] = None
# Find tick formats
formatter = axis.get_major_formatter()
if isinstance(formatter, ticker.NullFormatter):
props["tickformat"] = ""
elif isinstance(formatter, ticker.FixedFormatter):
props["tickformat"] = list(formatter.seq)
elif isinstance(formatter, ticker.FuncFormatter):
props["tickformat"] = list(formatter.func.args[0].values())
elif not any(label.get_visible() for label in axis.get_ticklabels()):
props["tickformat"] = ""
else:
props["tickformat"] = None
# Get axis scale
props["scale"] = axis.get_scale()
# Get major tick label size (assumes that's all we really care about!)
labels = axis.get_ticklabels()
if labels:
props["fontsize"] = labels[0].get_fontsize()
else:
props["fontsize"] = None
# Get associated grid
props["grid"] = get_grid_style(axis)
# get axis visibility
props["visible"] = axis.get_visible()
return props
def get_grid_style(axis):
gridlines = axis.get_gridlines()
if axis._major_tick_kw["gridOn"] and len(gridlines) > 0:
color = export_color(gridlines[0].get_color())
alpha = gridlines[0].get_alpha()
dasharray = get_dasharray(gridlines[0])
return dict(gridOn=True, color=color, dasharray=dasharray, alpha=alpha)
else:
return {"gridOn": False}
def get_figure_properties(fig):
return {
"figwidth": fig.get_figwidth(),
"figheight": fig.get_figheight(),
"dpi": fig.dpi,
}
def get_axes_properties(ax):
props = {
"axesbg": export_color(ax.patch.get_facecolor()),
"axesbgalpha": ax.patch.get_alpha(),
"bounds": ax.get_position().bounds,
"dynamic": ax.get_navigate(),
"axison": ax.axison,
"frame_on": ax.get_frame_on(),
"patch_visible": ax.patch.get_visible(),
"axes": [get_axis_properties(ax.xaxis), get_axis_properties(ax.yaxis)],
}
for axname in ["x", "y"]:
axis = getattr(ax, axname + "axis")
domain = getattr(ax, "get_{0}lim".format(axname))()
lim = domain
if isinstance(axis.converter, matplotlib.dates.DateConverter):
scale = "date"
try:
import pandas as pd
from pandas.tseries.converter import PeriodConverter
except ImportError:
pd = None
if pd is not None and isinstance(axis.converter, PeriodConverter):
_dates = [pd.Period(ordinal=int(d), freq=axis.freq) for d in domain]
domain = [
(d.year, d.month - 1, d.day, d.hour, d.minute, d.second, 0)
for d in _dates
]
else:
domain = [
(
d.year,
d.month - 1,
d.day,
d.hour,
d.minute,
d.second,
d.microsecond * 1e-3,
)
for d in matplotlib.dates.num2date(domain)
]
else:
scale = axis.get_scale()
if scale not in ["date", "linear", "log"]:
raise ValueError("Unknown axis scale: " "{0}".format(axis.get_scale()))
props[axname + "scale"] = scale
props[axname + "lim"] = lim
props[axname + "domain"] = domain
return props
def iter_all_children(obj, skipContainers=False):
"""
Returns an iterator over all childen and nested children using
obj's get_children() method
if skipContainers is true, only childless objects are returned.
"""
if hasattr(obj, "get_children") and len(obj.get_children()) > 0:
for child in obj.get_children():
if not skipContainers:
yield child
# could use `yield from` in python 3...
for grandchild in iter_all_children(child, skipContainers):
yield grandchild
else:
yield obj
def get_legend_properties(ax, legend):
handles, labels = ax.get_legend_handles_labels()
visible = legend.get_visible()
return {"handles": handles, "labels": labels, "visible": visible}
def image_to_base64(image):
"""
Convert a matplotlib image to a base64 png representation
Parameters
----------
image : matplotlib image object
The image to be converted.
Returns
-------
image_base64 : string
The UTF8-encoded base64 string representation of the png image.
"""
ax = image.axes
binary_buffer = io.BytesIO()
# image is saved in axes coordinates: we need to temporarily
# set the correct limits to get the correct image
lim = ax.axis()
ax.axis(image.get_extent())
image.write_png(binary_buffer)
ax.axis(lim)
binary_buffer.seek(0)
return base64.b64encode(binary_buffer.read()).decode("utf-8")

View File

@@ -0,0 +1,611 @@
"""
Tools
A module for converting from mpl language to plotly language.
"""
import math
import datetime
import warnings
import matplotlib.dates
def check_bar_match(old_bar, new_bar):
"""Check if two bars belong in the same collection (bar chart).
Positional arguments:
old_bar -- a previously sorted bar dictionary.
new_bar -- a new bar dictionary that needs to be sorted.
"""
tests = []
tests += (new_bar["orientation"] == old_bar["orientation"],)
tests += (new_bar["facecolor"] == old_bar["facecolor"],)
if new_bar["orientation"] == "v":
new_width = new_bar["x1"] - new_bar["x0"]
old_width = old_bar["x1"] - old_bar["x0"]
tests += (new_width - old_width < 0.000001,)
tests += (new_bar["y0"] == old_bar["y0"],)
elif new_bar["orientation"] == "h":
new_height = new_bar["y1"] - new_bar["y0"]
old_height = old_bar["y1"] - old_bar["y0"]
tests += (new_height - old_height < 0.000001,)
tests += (new_bar["x0"] == old_bar["x0"],)
if all(tests):
return True
else:
return False
def check_corners(inner_obj, outer_obj):
inner_corners = inner_obj.get_window_extent().corners()
outer_corners = outer_obj.get_window_extent().corners()
if inner_corners[0][0] < outer_corners[0][0]:
return False
elif inner_corners[0][1] < outer_corners[0][1]:
return False
elif inner_corners[3][0] > outer_corners[3][0]:
return False
elif inner_corners[3][1] > outer_corners[3][1]:
return False
else:
return True
def convert_dash(mpl_dash):
"""Convert mpl line symbol to plotly line symbol and return symbol."""
if mpl_dash in DASH_MAP:
return DASH_MAP[mpl_dash]
else:
dash_array = mpl_dash.split(",")
if len(dash_array) < 2:
return "solid"
# Catch the exception where the off length is zero, in case
# matplotlib 'solid' changes from '10,0' to 'N,0'
if math.isclose(float(dash_array[1]), 0.0):
return "solid"
# If we can't find the dash pattern in the map, convert it
# into custom values in px, e.g. '7,5' -> '7px,5px'
dashpx = ",".join([x + "px" for x in dash_array])
# TODO: rewrite the convert_dash code
# only strings 'solid', 'dashed', etc allowed
if dashpx == "7.4px,3.2px":
dashpx = "dashed"
elif dashpx == "12.8px,3.2px,2.0px,3.2px":
dashpx = "dashdot"
elif dashpx == "2.0px,3.3px":
dashpx = "dotted"
return dashpx
def convert_path(path):
verts = path[0] # may use this later
code = tuple(path[1])
if code in PATH_MAP:
return PATH_MAP[code]
else:
return None
def convert_symbol(mpl_symbol):
"""Convert mpl marker symbol to plotly symbol and return symbol."""
if isinstance(mpl_symbol, list):
symbol = list()
for s in mpl_symbol:
symbol += [convert_symbol(s)]
return symbol
elif mpl_symbol in SYMBOL_MAP:
return SYMBOL_MAP[mpl_symbol]
else:
return "circle" # default
def hex_to_rgb(value):
"""
Change a hex color to an rgb tuple
:param (str|unicode) value: The hex string we want to convert.
:return: (int, int, int) The red, green, blue int-tuple.
Example:
'#FFFFFF' --> (255, 255, 255)
"""
value = value.lstrip("#")
lv = len(value)
return tuple(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3))
def merge_color_and_opacity(color, opacity):
"""
Merge hex color with an alpha (opacity) to get an rgba tuple.
:param (str|unicode) color: A hex color string.
:param (float|int) opacity: A value [0, 1] for the 'a' in 'rgba'.
:return: (int, int, int, float) The rgba color and alpha tuple.
"""
if color is None: # None can be used as a placeholder, just bail.
return None
rgb_tup = hex_to_rgb(color)
if opacity is None:
return "rgb {}".format(rgb_tup)
rgba_tup = rgb_tup + (opacity,)
return "rgba {}".format(rgba_tup)
def convert_va(mpl_va):
"""Convert mpl vertical alignment word to equivalent HTML word.
Text alignment specifiers from mpl differ very slightly from those used
in HTML. See the VA_MAP for more details.
Positional arguments:
mpl_va -- vertical mpl text alignment spec.
"""
if mpl_va in VA_MAP:
return VA_MAP[mpl_va]
else:
return None # let plotly figure it out!
def convert_x_domain(mpl_plot_bounds, mpl_max_x_bounds):
"""Map x dimension of current plot to plotly's domain space.
The bbox used to locate an axes object in mpl differs from the
method used to locate axes in plotly. The mpl version locates each
axes in the figure so that axes in a single-plot figure might have
the bounds, [0.125, 0.125, 0.775, 0.775] (x0, y0, width, height),
in mpl's figure coordinates. However, the axes all share one space in
plotly such that the domain will always be [0, 0, 1, 1]
(x0, y0, x1, y1). To convert between the two, the mpl figure bounds
need to be mapped to a [0, 1] domain for x and y. The margins set
upon opening a new figure will appropriately match the mpl margins.
Optionally, setting margins=0 and simply copying the domains from
mpl to plotly would place axes appropriately. However,
this would throw off axis and title labeling.
Positional arguments:
mpl_plot_bounds -- the (x0, y0, width, height) params for current ax **
mpl_max_x_bounds -- overall (x0, x1) bounds for all axes **
** these are all specified in mpl figure coordinates
"""
mpl_x_dom = [mpl_plot_bounds[0], mpl_plot_bounds[0] + mpl_plot_bounds[2]]
plotting_width = mpl_max_x_bounds[1] - mpl_max_x_bounds[0]
x0 = (mpl_x_dom[0] - mpl_max_x_bounds[0]) / plotting_width
x1 = (mpl_x_dom[1] - mpl_max_x_bounds[0]) / plotting_width
return [x0, x1]
def convert_y_domain(mpl_plot_bounds, mpl_max_y_bounds):
"""Map y dimension of current plot to plotly's domain space.
The bbox used to locate an axes object in mpl differs from the
method used to locate axes in plotly. The mpl version locates each
axes in the figure so that axes in a single-plot figure might have
the bounds, [0.125, 0.125, 0.775, 0.775] (x0, y0, width, height),
in mpl's figure coordinates. However, the axes all share one space in
plotly such that the domain will always be [0, 0, 1, 1]
(x0, y0, x1, y1). To convert between the two, the mpl figure bounds
need to be mapped to a [0, 1] domain for x and y. The margins set
upon opening a new figure will appropriately match the mpl margins.
Optionally, setting margins=0 and simply copying the domains from
mpl to plotly would place axes appropriately. However,
this would throw off axis and title labeling.
Positional arguments:
mpl_plot_bounds -- the (x0, y0, width, height) params for current ax **
mpl_max_y_bounds -- overall (y0, y1) bounds for all axes **
** these are all specified in mpl figure coordinates
"""
mpl_y_dom = [mpl_plot_bounds[1], mpl_plot_bounds[1] + mpl_plot_bounds[3]]
plotting_height = mpl_max_y_bounds[1] - mpl_max_y_bounds[0]
y0 = (mpl_y_dom[0] - mpl_max_y_bounds[0]) / plotting_height
y1 = (mpl_y_dom[1] - mpl_max_y_bounds[0]) / plotting_height
return [y0, y1]
def display_to_paper(x, y, layout):
"""Convert mpl display coordinates to plotly paper coordinates.
Plotly references object positions with an (x, y) coordinate pair in either
'data' or 'paper' coordinates which reference actual data in a plot or
the entire plotly axes space where the bottom-left of the bottom-left
plot has the location (x, y) = (0, 0) and the top-right of the top-right
plot has the location (x, y) = (1, 1). Display coordinates in mpl reference
objects with an (x, y) pair in pixel coordinates, where the bottom-left
corner is at the location (x, y) = (0, 0) and the top-right corner is at
the location (x, y) = (figwidth*dpi, figheight*dpi). Here, figwidth and
figheight are in inches and dpi are the dots per inch resolution.
"""
num_x = x - layout["margin"]["l"]
den_x = layout["width"] - (layout["margin"]["l"] + layout["margin"]["r"])
num_y = y - layout["margin"]["b"]
den_y = layout["height"] - (layout["margin"]["b"] + layout["margin"]["t"])
return num_x / den_x, num_y / den_y
def get_axes_bounds(fig):
"""Return the entire axes space for figure.
An axes object in mpl is specified by its relation to the figure where
(0,0) corresponds to the bottom-left part of the figure and (1,1)
corresponds to the top-right. Margins exist in matplotlib because axes
objects normally don't go to the edges of the figure.
In plotly, the axes area (where all subplots go) is always specified with
the domain [0,1] for both x and y. This function finds the smallest box,
specified by two points, that all of the mpl axes objects fit into. This
box is then used to map mpl axes domains to plotly axes domains.
"""
x_min, x_max, y_min, y_max = [], [], [], []
for axes_obj in fig.get_axes():
bounds = axes_obj.get_position().bounds
x_min.append(bounds[0])
x_max.append(bounds[0] + bounds[2])
y_min.append(bounds[1])
y_max.append(bounds[1] + bounds[3])
x_min, y_min, x_max, y_max = min(x_min), min(y_min), max(x_max), max(y_max)
return (x_min, x_max), (y_min, y_max)
def get_axis_mirror(main_spine, mirror_spine):
if main_spine and mirror_spine:
return "ticks"
elif main_spine and not mirror_spine:
return False
elif not main_spine and mirror_spine:
return False # can't handle this case yet!
else:
return False # nuttin'!
def get_bar_gap(bar_starts, bar_ends, tol=1e-10):
if len(bar_starts) == len(bar_ends) and len(bar_starts) > 1:
sides1 = bar_starts[1:]
sides2 = bar_ends[:-1]
gaps = [s2 - s1 for s2, s1 in zip(sides1, sides2)]
gap0 = gaps[0]
uniform = all([abs(gap0 - gap) < tol for gap in gaps])
if uniform:
return gap0
def convert_rgba_array(color_list):
clean_color_list = list()
for c in color_list:
clean_color_list += [
(dict(r=int(c[0] * 255), g=int(c[1] * 255), b=int(c[2] * 255), a=c[3]))
]
plotly_colors = list()
for rgba in clean_color_list:
plotly_colors += ["rgba({r},{g},{b},{a})".format(**rgba)]
if len(plotly_colors) == 1:
return plotly_colors[0]
else:
return plotly_colors
def convert_path_array(path_array):
symbols = list()
for path in path_array:
symbols += [convert_path(path)]
if len(symbols) == 1:
return symbols[0]
else:
return symbols
def convert_linewidth_array(width_array):
if len(width_array) == 1:
return width_array[0]
else:
return width_array
def convert_size_array(size_array):
size = [math.sqrt(s) for s in size_array]
if len(size) == 1:
return size[0]
else:
return size
def get_markerstyle_from_collection(props):
markerstyle = dict(
alpha=None,
facecolor=convert_rgba_array(props["styles"]["facecolor"]),
marker=convert_path_array(props["paths"]),
edgewidth=convert_linewidth_array(props["styles"]["linewidth"]),
# markersize=convert_size_array(props['styles']['size']), # TODO!
markersize=convert_size_array(props["mplobj"].get_sizes()),
edgecolor=convert_rgba_array(props["styles"]["edgecolor"]),
)
return markerstyle
def get_rect_xmin(data):
"""Find minimum x value from four (x,y) vertices."""
return min(data[0][0], data[1][0], data[2][0], data[3][0])
def get_rect_xmax(data):
"""Find maximum x value from four (x,y) vertices."""
return max(data[0][0], data[1][0], data[2][0], data[3][0])
def get_rect_ymin(data):
"""Find minimum y value from four (x,y) vertices."""
return min(data[0][1], data[1][1], data[2][1], data[3][1])
def get_rect_ymax(data):
"""Find maximum y value from four (x,y) vertices."""
return max(data[0][1], data[1][1], data[2][1], data[3][1])
def get_spine_visible(ax, spine_key):
"""Return some spine parameters for the spine, `spine_key`."""
spine = ax.spines[spine_key]
ax_frame_on = ax.get_frame_on()
position = spine._position or ("outward", 0.0)
if isinstance(position, str):
if position == "center":
position = ("axes", 0.5)
elif position == "zero":
position = ("data", 0)
position_type, amount = position
if position_type == "outward" and amount == 0:
spine_frame_like = True
else:
spine_frame_like = False
if not spine.get_visible():
return False
elif not spine._edgecolor[-1]: # user's may have set edgecolor alpha==0
return False
elif not ax_frame_on and spine_frame_like:
return False
elif ax_frame_on and spine_frame_like:
return True
elif not ax_frame_on and not spine_frame_like:
return True # we've already checked for that it's visible.
else:
return False # oh man, and i thought we exhausted the options...
def is_bar(bar_containers, **props):
"""A test to decide whether a path is a bar from a vertical bar chart."""
# is this patch in a bar container?
for container in bar_containers:
if props["mplobj"] in container:
return True
return False
def make_bar(**props):
"""Make an intermediate bar dictionary.
This creates a bar dictionary which aids in the comparison of new bars to
old bars from other bar chart (patch) collections. This is not the
dictionary that needs to get passed to plotly as a data dictionary. That
happens in PlotlyRenderer in that class's draw_bar method. In other
words, this dictionary describes a SINGLE bar, whereas, plotly will
require a set of bars to be passed in a data dictionary.
"""
return {
"bar": props["mplobj"],
"x0": get_rect_xmin(props["data"]),
"y0": get_rect_ymin(props["data"]),
"x1": get_rect_xmax(props["data"]),
"y1": get_rect_ymax(props["data"]),
"alpha": props["style"]["alpha"],
"edgecolor": props["style"]["edgecolor"],
"facecolor": props["style"]["facecolor"],
"edgewidth": props["style"]["edgewidth"],
"dasharray": props["style"]["dasharray"],
"zorder": props["style"]["zorder"],
}
def prep_ticks(ax, index, ax_type, props):
"""Prepare axis obj belonging to axes obj.
positional arguments:
ax - the mpl axes instance
index - the index of the axis in `props`
ax_type - 'x' or 'y' (for now)
props - an mplexporter poperties dictionary
"""
axis_dict = dict()
if ax_type == "x":
axis = ax.get_xaxis()
elif ax_type == "y":
axis = ax.get_yaxis()
else:
return dict() # whoops!
scale = props["axes"][index]["scale"]
if scale == "linear":
# get tick location information
try:
tickvalues = props["axes"][index]["tickvalues"]
tick0 = tickvalues[0]
dticks = [
round(tickvalues[i] - tickvalues[i - 1], 12)
for i in range(1, len(tickvalues) - 1)
]
if all([dticks[i] == dticks[i - 1] for i in range(1, len(dticks) - 1)]):
dtick = tickvalues[1] - tickvalues[0]
else:
warnings.warn(
"'linear' {0}-axis tick spacing not even, "
"ignoring mpl tick formatting.".format(ax_type)
)
raise TypeError
except (IndexError, TypeError):
axis_dict["nticks"] = props["axes"][index]["nticks"]
else:
axis_dict["tick0"] = tick0
axis_dict["dtick"] = dtick
axis_dict["tickmode"] = None
elif scale == "log":
try:
axis_dict["tick0"] = props["axes"][index]["tickvalues"][0]
axis_dict["dtick"] = (
props["axes"][index]["tickvalues"][1]
- props["axes"][index]["tickvalues"][0]
)
axis_dict["tickmode"] = None
except (IndexError, TypeError):
axis_dict = dict(nticks=props["axes"][index]["nticks"])
base = axis.get_transform().base
if base == 10:
if ax_type == "x":
axis_dict["range"] = [
math.log10(props["xlim"][0]),
math.log10(props["xlim"][1]),
]
elif ax_type == "y":
axis_dict["range"] = [
math.log10(props["ylim"][0]),
math.log10(props["ylim"][1]),
]
else:
axis_dict = dict(range=None, type="linear")
warnings.warn(
"Converted non-base10 {0}-axis log scale to 'linear'" "".format(ax_type)
)
else:
return dict()
# get tick label formatting information
formatter = axis.get_major_formatter().__class__.__name__
if ax_type == "x" and "DateFormatter" in formatter:
axis_dict["type"] = "date"
try:
axis_dict["tick0"] = mpl_dates_to_datestrings(axis_dict["tick0"], formatter)
except KeyError:
pass
finally:
axis_dict.pop("dtick", None)
axis_dict.pop("tickmode", None)
axis_dict["range"] = mpl_dates_to_datestrings(props["xlim"], formatter)
if formatter == "LogFormatterMathtext":
axis_dict["exponentformat"] = "e"
return axis_dict
def prep_xy_axis(ax, props, x_bounds, y_bounds):
xaxis = dict(
type=props["axes"][0]["scale"],
range=list(props["xlim"]),
showgrid=props["axes"][0]["grid"]["gridOn"],
domain=convert_x_domain(props["bounds"], x_bounds),
side=props["axes"][0]["position"],
tickfont=dict(size=props["axes"][0]["fontsize"]),
)
xaxis.update(prep_ticks(ax, 0, "x", props))
yaxis = dict(
type=props["axes"][1]["scale"],
range=list(props["ylim"]),
showgrid=props["axes"][1]["grid"]["gridOn"],
domain=convert_y_domain(props["bounds"], y_bounds),
side=props["axes"][1]["position"],
tickfont=dict(size=props["axes"][1]["fontsize"]),
)
yaxis.update(prep_ticks(ax, 1, "y", props))
return xaxis, yaxis
def mpl_dates_to_datestrings(dates, mpl_formatter):
"""Convert matplotlib dates to iso-formatted-like time strings.
Plotly's accepted format: "YYYY-MM-DD HH:MM:SS" (e.g., 2001-01-01 00:00:00)
Info on mpl dates: http://matplotlib.org/api/dates_api.html
"""
_dates = dates
# this is a pandas datetime formatter, times show up in floating point days
# since the epoch (1970-01-01T00:00:00+00:00)
if mpl_formatter == "TimeSeries_DateFormatter":
try:
dates = matplotlib.dates.epoch2num([date * 24 * 60 * 60 for date in dates])
dates = matplotlib.dates.num2date(dates)
except:
return _dates
# the rest of mpl dates are in floating point days since
# (0001-01-01T00:00:00+00:00) + 1. I.e., (0001-01-01T00:00:00+00:00) == 1.0
# according to mpl --> try num2date(1)
else:
try:
dates = matplotlib.dates.num2date(dates)
except:
return _dates
time_stings = [
" ".join(date.isoformat().split("+")[0].split("T")) for date in dates
]
return time_stings
# dashed is dash in matplotlib
DASH_MAP = {
"10,0": "solid",
"6,6": "dash",
"2,2": "circle",
"4,4,2,4": "dashdot",
"none": "solid",
"7.4,3.2": "dash",
}
PATH_MAP = {
("M", "C", "C", "C", "C", "C", "C", "C", "C", "Z"): "o",
("M", "L", "L", "L", "L", "L", "L", "L", "L", "L", "Z"): "*",
("M", "L", "L", "L", "L", "L", "L", "L", "Z"): "8",
("M", "L", "L", "L", "L", "L", "Z"): "h",
("M", "L", "L", "L", "L", "Z"): "p",
("M", "L", "M", "L", "M", "L"): "1",
("M", "L", "L", "L", "Z"): "s",
("M", "L", "M", "L"): "+",
("M", "L", "L", "Z"): "^",
("M", "L"): "|",
}
SYMBOL_MAP = {
"o": "circle",
"v": "triangle-down",
"^": "triangle-up",
"<": "triangle-left",
">": "triangle-right",
"s": "square",
"+": "cross",
"x": "x",
"*": "star",
"D": "diamond",
"d": "diamond",
}
VA_MAP = {"center": "middle", "baseline": "bottom", "top": "top"}

View File

@@ -0,0 +1,865 @@
"""
Renderer Module
This module defines the PlotlyRenderer class and a single function,
fig_to_plotly, which is intended to be the main way that user's will interact
with the matplotlylib package.
"""
from __future__ import absolute_import
import warnings
import plotly.graph_objs as go
from plotly.matplotlylib.mplexporter import Renderer
from plotly.matplotlylib import mpltools
# Warning format
def warning_on_one_line(msg, category, filename, lineno, file=None, line=None):
return "%s:%s: %s:\n\n%s\n\n" % (filename, lineno, category.__name__, msg)
warnings.formatwarning = warning_on_one_line
class PlotlyRenderer(Renderer):
"""A renderer class inheriting from base for rendering mpl plots in plotly.
A renderer class to be used with an exporter for rendering matplotlib
plots in Plotly. This module defines the PlotlyRenderer class which handles
the creation of the JSON structures that get sent to plotly.
All class attributes available are defined in __init__().
Basic Usage:
# (mpl code) #
fig = gcf()
renderer = PlotlyRenderer(fig)
exporter = Exporter(renderer)
exporter.run(fig) # ... et voila
"""
def __init__(self):
"""Initialize PlotlyRenderer obj.
PlotlyRenderer obj is called on by an Exporter object to draw
matplotlib objects like figures, axes, text, etc.
All class attributes are listed here in the __init__ method.
"""
self.plotly_fig = go.Figure()
self.mpl_fig = None
self.current_mpl_ax = None
self.bar_containers = None
self.current_bars = []
self.axis_ct = 0
self.x_is_mpl_date = False
self.mpl_x_bounds = (0, 1)
self.mpl_y_bounds = (0, 1)
self.msg = "Initialized PlotlyRenderer\n"
def open_figure(self, fig, props):
"""Creates a new figure by beginning to fill out layout dict.
The 'autosize' key is set to false so that the figure will mirror
sizes set by mpl. The 'hovermode' key controls what shows up when you
mouse around a figure in plotly, it's set to show the 'closest' point.
Positional agurments:
fig -- a matplotlib.figure.Figure object.
props.keys(): [
'figwidth',
'figheight',
'dpi'
]
"""
self.msg += "Opening figure\n"
self.mpl_fig = fig
self.plotly_fig["layout"] = go.Layout(
width=int(props["figwidth"] * props["dpi"]),
height=int(props["figheight"] * props["dpi"]),
autosize=False,
hovermode="closest",
)
self.mpl_x_bounds, self.mpl_y_bounds = mpltools.get_axes_bounds(fig)
margin = go.layout.Margin(
l=int(self.mpl_x_bounds[0] * self.plotly_fig["layout"]["width"]),
r=int((1 - self.mpl_x_bounds[1]) * self.plotly_fig["layout"]["width"]),
t=int((1 - self.mpl_y_bounds[1]) * self.plotly_fig["layout"]["height"]),
b=int(self.mpl_y_bounds[0] * self.plotly_fig["layout"]["height"]),
pad=0,
)
self.plotly_fig["layout"]["margin"] = margin
def close_figure(self, fig):
"""Closes figure by cleaning up data and layout dictionaries.
The PlotlyRenderer's job is to create an appropriate set of data and
layout dictionaries. When the figure is closed, some cleanup and
repair is necessary. This method removes inappropriate dictionary
entries, freeing up Plotly to use defaults and best judgements to
complete the entries. This method is called by an Exporter object.
Positional arguments:
fig -- a matplotlib.figure.Figure object.
"""
self.plotly_fig["layout"]["showlegend"] = False
self.msg += "Closing figure\n"
def open_axes(self, ax, props):
"""Setup a new axes object (subplot in plotly).
Plotly stores information about subplots in different 'xaxis' and
'yaxis' objects which are numbered. These are just dictionaries
included in the layout dictionary. This function takes information
from the Exporter, fills in appropriate dictionary entries,
and updates the layout dictionary. PlotlyRenderer keeps track of the
number of plots by incrementing the axis_ct attribute.
Setting the proper plot domain in plotly is a bit tricky. Refer to
the documentation for mpltools.convert_x_domain and
mpltools.convert_y_domain.
Positional arguments:
ax -- an mpl axes object. This will become a subplot in plotly.
props.keys() -- [
'axesbg', (background color for axes obj)
'axesbgalpha', (alpha, or opacity for background)
'bounds', ((x0, y0, width, height) for axes)
'dynamic', (zoom/pan-able?)
'axes', (list: [xaxis, yaxis])
'xscale', (log, linear, or date)
'yscale',
'xlim', (range limits for x)
'ylim',
'xdomain' (xdomain=xlim, unless it's a date)
'ydomain'
]
"""
self.msg += " Opening axes\n"
self.current_mpl_ax = ax
self.bar_containers = [
c
for c in ax.containers # empty is OK
if c.__class__.__name__ == "BarContainer"
]
self.current_bars = []
self.axis_ct += 1
# set defaults in axes
xaxis = go.layout.XAxis(
anchor="y{0}".format(self.axis_ct), zeroline=False, ticks="inside"
)
yaxis = go.layout.YAxis(
anchor="x{0}".format(self.axis_ct), zeroline=False, ticks="inside"
)
# update defaults with things set in mpl
mpl_xaxis, mpl_yaxis = mpltools.prep_xy_axis(
ax=ax, props=props, x_bounds=self.mpl_x_bounds, y_bounds=self.mpl_y_bounds
)
xaxis.update(mpl_xaxis)
yaxis.update(mpl_yaxis)
bottom_spine = mpltools.get_spine_visible(ax, "bottom")
top_spine = mpltools.get_spine_visible(ax, "top")
left_spine = mpltools.get_spine_visible(ax, "left")
right_spine = mpltools.get_spine_visible(ax, "right")
xaxis["mirror"] = mpltools.get_axis_mirror(bottom_spine, top_spine)
yaxis["mirror"] = mpltools.get_axis_mirror(left_spine, right_spine)
xaxis["showline"] = bottom_spine
yaxis["showline"] = top_spine
# put axes in our figure
self.plotly_fig["layout"]["xaxis{0}".format(self.axis_ct)] = xaxis
self.plotly_fig["layout"]["yaxis{0}".format(self.axis_ct)] = yaxis
# let all subsequent dates be handled properly if required
if "type" in dir(xaxis) and xaxis["type"] == "date":
self.x_is_mpl_date = True
def close_axes(self, ax):
"""Close the axes object and clean up.
Bars from bar charts are given to PlotlyRenderer one-by-one,
thus they need to be taken care of at the close of each axes object.
The self.current_bars variable should be empty unless a bar
chart has been created.
Positional arguments:
ax -- an mpl axes object, not required at this time.
"""
self.draw_bars(self.current_bars)
self.msg += " Closing axes\n"
self.x_is_mpl_date = False
def draw_bars(self, bars):
# sort bars according to bar containers
mpl_traces = []
for container in self.bar_containers:
mpl_traces.append(
[
bar_props
for bar_props in self.current_bars
if bar_props["mplobj"] in container
]
)
for trace in mpl_traces:
self.draw_bar(trace)
def draw_bar(self, coll):
"""Draw a collection of similar patches as a bar chart.
After bars are sorted, an appropriate data dictionary must be created
to tell plotly about this data. Just like draw_line or draw_markers,
draw_bar translates patch/path information into something plotly
understands.
Positional arguments:
patch_coll -- a collection of patches to be drawn as a bar chart.
"""
tol = 1e-10
trace = [mpltools.make_bar(**bar_props) for bar_props in coll]
widths = [bar_props["x1"] - bar_props["x0"] for bar_props in trace]
heights = [bar_props["y1"] - bar_props["y0"] for bar_props in trace]
vertical = abs(sum(widths[0] - widths[iii] for iii in range(len(widths)))) < tol
horizontal = (
abs(sum(heights[0] - heights[iii] for iii in range(len(heights)))) < tol
)
if vertical and horizontal:
# Check for monotonic x. Can't both be true!
x_zeros = [bar_props["x0"] for bar_props in trace]
if all(
(x_zeros[iii + 1] > x_zeros[iii] for iii in range(len(x_zeros[:-1])))
):
orientation = "v"
else:
orientation = "h"
elif vertical:
orientation = "v"
else:
orientation = "h"
if orientation == "v":
self.msg += " Attempting to draw a vertical bar chart\n"
old_heights = [bar_props["y1"] for bar_props in trace]
for bar in trace:
bar["y0"], bar["y1"] = 0, bar["y1"] - bar["y0"]
new_heights = [bar_props["y1"] for bar_props in trace]
# check if we're stacked or not...
for old, new in zip(old_heights, new_heights):
if abs(old - new) > tol:
self.plotly_fig["layout"]["barmode"] = "stack"
self.plotly_fig["layout"]["hovermode"] = "x"
x = [bar["x0"] + (bar["x1"] - bar["x0"]) / 2 for bar in trace]
y = [bar["y1"] for bar in trace]
bar_gap = mpltools.get_bar_gap(
[bar["x0"] for bar in trace], [bar["x1"] for bar in trace]
)
if self.x_is_mpl_date:
x = [bar["x0"] for bar in trace]
formatter = (
self.current_mpl_ax.get_xaxis()
.get_major_formatter()
.__class__.__name__
)
x = mpltools.mpl_dates_to_datestrings(x, formatter)
else:
self.msg += " Attempting to draw a horizontal bar chart\n"
old_rights = [bar_props["x1"] for bar_props in trace]
for bar in trace:
bar["x0"], bar["x1"] = 0, bar["x1"] - bar["x0"]
new_rights = [bar_props["x1"] for bar_props in trace]
# check if we're stacked or not...
for old, new in zip(old_rights, new_rights):
if abs(old - new) > tol:
self.plotly_fig["layout"]["barmode"] = "stack"
self.plotly_fig["layout"]["hovermode"] = "y"
x = [bar["x1"] for bar in trace]
y = [bar["y0"] + (bar["y1"] - bar["y0"]) / 2 for bar in trace]
bar_gap = mpltools.get_bar_gap(
[bar["y0"] for bar in trace], [bar["y1"] for bar in trace]
)
bar = go.Bar(
orientation=orientation,
x=x,
y=y,
xaxis="x{0}".format(self.axis_ct),
yaxis="y{0}".format(self.axis_ct),
opacity=trace[0]["alpha"], # TODO: get all alphas if array?
marker=go.bar.Marker(
color=trace[0]["facecolor"], # TODO: get all
line=dict(width=trace[0]["edgewidth"]),
),
) # TODO ditto
if len(bar["x"]) > 1:
self.msg += " Heck yeah, I drew that bar chart\n"
self.plotly_fig.add_trace(bar),
if bar_gap is not None:
self.plotly_fig["layout"]["bargap"] = bar_gap
else:
self.msg += " Bar chart not drawn\n"
warnings.warn(
"found box chart data with length <= 1, "
"assuming data redundancy, not plotting."
)
def draw_legend_shapes(self, mode, shape, **props):
"""Create a shape that matches lines or markers in legends.
Main issue is that path for circles do not render, so we have to use 'circle'
instead of 'path'.
"""
for single_mode in mode.split("+"):
x = props["data"][0][0]
y = props["data"][0][1]
if single_mode == "markers" and props.get("markerstyle"):
size = shape.pop("size", 6)
symbol = shape.pop("symbol")
# aligning to "center"
x0 = 0
y0 = 0
x1 = size
y1 = size
markerpath = props["markerstyle"].get("markerpath")
if markerpath is None and symbol != "circle":
self.msg += (
"not sure how to handle this marker without a valid path\n"
)
return
# marker path to SVG path conversion
path = " ".join(
[f"{a} {t[0]},{t[1]}" for a, t in zip(markerpath[1], markerpath[0])]
)
if symbol == "circle":
# symbols like . and o in matplotlib, use circle
# plotly also maps many other markers to circle, such as 1,8 and p
path = None
shape_type = "circle"
x0 = -size / 2
y0 = size / 2
x1 = size / 2
y1 = size + size / 2
else:
# triangles, star etc
shape_type = "path"
legend_shape = go.layout.Shape(
type=shape_type,
xref="paper",
yref="paper",
x0=x0,
y0=y0,
x1=x1,
y1=y1,
xsizemode="pixel",
ysizemode="pixel",
xanchor=x,
yanchor=y,
path=path,
**shape,
)
elif single_mode == "lines":
mode = "line"
x1 = props["data"][1][0]
y1 = props["data"][1][1]
legend_shape = go.layout.Shape(
type=mode,
xref="paper",
yref="paper",
x0=x,
y0=y + 0.02,
x1=x1,
y1=y1 + 0.02,
**shape,
)
else:
self.msg += "not sure how to handle this element\n"
return
self.plotly_fig.add_shape(legend_shape)
self.msg += " Heck yeah, I drew that shape\n"
def draw_marked_line(self, **props):
"""Create a data dict for a line obj.
This will draw 'lines', 'markers', or 'lines+markers'. For legend elements,
this will use layout.shapes, so they can be positioned with paper refs.
props.keys() -- [
'coordinates', ('data', 'axes', 'figure', or 'display')
'data', (a list of xy pairs)
'mplobj', (the matplotlib.lines.Line2D obj being rendered)
'label', (the name of the Line2D obj being rendered)
'linestyle', (linestyle dict, can be None, see below)
'markerstyle', (markerstyle dict, can be None, see below)
]
props['linestyle'].keys() -- [
'alpha', (opacity of Line2D obj)
'color', (color of the line if it exists, not the marker)
'linewidth',
'dasharray', (code for linestyle, see DASH_MAP in mpltools.py)
'zorder', (viewing precedence when stacked with other objects)
]
props['markerstyle'].keys() -- [
'alpha', (opacity of Line2D obj)
'marker', (the mpl marker symbol, see SYMBOL_MAP in mpltools.py)
'facecolor', (color of the marker face)
'edgecolor', (color of the marker edge)
'edgewidth', (width of marker edge)
'markerpath', (an SVG path for drawing the specified marker)
'zorder', (viewing precedence when stacked with other objects)
]
"""
self.msg += " Attempting to draw a line "
line, marker, shape = {}, {}, {}
if props["linestyle"] and props["markerstyle"]:
self.msg += "... with both lines+markers\n"
mode = "lines+markers"
elif props["linestyle"]:
self.msg += "... with just lines\n"
mode = "lines"
elif props["markerstyle"]:
self.msg += "... with just markers\n"
mode = "markers"
if props["linestyle"]:
color = mpltools.merge_color_and_opacity(
props["linestyle"]["color"], props["linestyle"]["alpha"]
)
if props["coordinates"] == "data":
line = go.scatter.Line(
color=color,
width=props["linestyle"]["linewidth"],
dash=mpltools.convert_dash(props["linestyle"]["dasharray"]),
)
else:
shape = dict(
line=dict(
color=color,
width=props["linestyle"]["linewidth"],
dash=mpltools.convert_dash(props["linestyle"]["dasharray"]),
)
)
if props["markerstyle"]:
if props["coordinates"] == "data":
marker = go.scatter.Marker(
opacity=props["markerstyle"]["alpha"],
color=props["markerstyle"]["facecolor"],
symbol=mpltools.convert_symbol(props["markerstyle"]["marker"]),
size=props["markerstyle"]["markersize"],
line=dict(
color=props["markerstyle"]["edgecolor"],
width=props["markerstyle"]["edgewidth"],
),
)
else:
shape = dict(
opacity=props["markerstyle"]["alpha"],
fillcolor=props["markerstyle"]["facecolor"],
symbol=mpltools.convert_symbol(props["markerstyle"]["marker"]),
size=props["markerstyle"]["markersize"],
line=dict(
color=props["markerstyle"]["edgecolor"],
width=props["markerstyle"]["edgewidth"],
),
)
if props["coordinates"] == "data":
marked_line = go.Scatter(
mode=mode,
name=(
str(props["label"])
if isinstance(props["label"], str)
else props["label"]
),
x=[xy_pair[0] for xy_pair in props["data"]],
y=[xy_pair[1] for xy_pair in props["data"]],
xaxis="x{0}".format(self.axis_ct),
yaxis="y{0}".format(self.axis_ct),
line=line,
marker=marker,
)
if self.x_is_mpl_date:
formatter = (
self.current_mpl_ax.get_xaxis()
.get_major_formatter()
.__class__.__name__
)
marked_line["x"] = mpltools.mpl_dates_to_datestrings(
marked_line["x"], formatter
)
self.plotly_fig.add_trace(marked_line),
self.msg += " Heck yeah, I drew that line\n"
elif props["coordinates"] == "axes":
# dealing with legend graphical elements
self.draw_legend_shapes(mode=mode, shape=shape, **props)
else:
self.msg += " Line didn't have 'data' coordinates, " "not drawing\n"
warnings.warn(
"Bummer! Plotly can currently only draw Line2D "
"objects from matplotlib that are in 'data' "
"coordinates!"
)
def draw_image(self, **props):
"""Draw image.
Not implemented yet!
"""
self.msg += " Attempting to draw image\n"
self.msg += " Not drawing image\n"
warnings.warn(
"Aw. Snap! You're gonna have to hold off on "
"the selfies for now. Plotly can't import "
"images from matplotlib yet!"
)
def draw_path_collection(self, **props):
"""Add a path collection to data list as a scatter plot.
Current implementation defaults such collections as scatter plots.
Matplotlib supports collections that have many of the same parameters
in common like color, size, path, etc. However, they needn't all be
the same. Plotly does not currently support such functionality and
therefore, the style for the first object is taken and used to define
the remaining paths in the collection.
props.keys() -- [
'paths', (structure: [vertices, path_code])
'path_coordinates', ('data', 'axes', 'figure', or 'display')
'path_transforms', (mpl transform, including Affine2D matrix)
'offsets', (offset from axes, helpful if in 'data')
'offset_coordinates', ('data', 'axes', 'figure', or 'display')
'offset_order',
'styles', (style dict, see below)
'mplobj' (the collection obj being drawn)
]
props['styles'].keys() -- [
'linewidth', (one or more linewidths)
'facecolor', (one or more facecolors for path)
'edgecolor', (one or more edgecolors for path)
'alpha', (one or more opacites for path)
'zorder', (precedence when stacked)
]
"""
self.msg += " Attempting to draw a path collection\n"
if props["offset_coordinates"] == "data":
markerstyle = mpltools.get_markerstyle_from_collection(props)
scatter_props = {
"coordinates": "data",
"data": props["offsets"],
"label": None,
"markerstyle": markerstyle,
"linestyle": None,
}
self.msg += " Drawing path collection as markers\n"
self.draw_marked_line(**scatter_props)
else:
self.msg += " Path collection not linked to 'data', " "not drawing\n"
warnings.warn(
"Dang! That path collection is out of this "
"world. I totally don't know what to do with "
"it yet! Plotly can only import path "
"collections linked to 'data' coordinates"
)
def draw_path(self, **props):
"""Draw path, currently only attempts to draw bar charts.
This function attempts to sort a given path into a collection of
horizontal or vertical bar charts. Most of the actual code takes
place in functions from mpltools.py.
props.keys() -- [
'data', (a list of verticies for the path)
'coordinates', ('data', 'axes', 'figure', or 'display')
'pathcodes', (code for the path, structure: ['M', 'L', 'Z', etc.])
'style', (style dict, see below)
'mplobj' (the mpl path object)
]
props['style'].keys() -- [
'alpha', (opacity of path obj)
'edgecolor',
'facecolor',
'edgewidth',
'dasharray', (style for path's enclosing line)
'zorder' (precedence of obj when stacked)
]
"""
self.msg += " Attempting to draw a path\n"
is_bar = mpltools.is_bar(self.current_mpl_ax.containers, **props)
if is_bar:
self.current_bars += [props]
else:
self.msg += " This path isn't a bar, not drawing\n"
warnings.warn(
"I found a path object that I don't think is part "
"of a bar chart. Ignoring."
)
def draw_text(self, **props):
"""Create an annotation dict for a text obj.
Currently, plotly uses either 'page' or 'data' to reference
annotation locations. These refer to 'display' and 'data',
respectively for the 'coordinates' key used in the Exporter.
Appropriate measures are taken to transform text locations to
reference one of these two options.
props.keys() -- [
'text', (actual content string, not the text obj)
'position', (an x, y pair, not an mpl Bbox)
'coordinates', ('data', 'axes', 'figure', 'display')
'text_type', ('title', 'xlabel', or 'ylabel')
'style', (style dict, see below)
'mplobj' (actual mpl text object)
]
props['style'].keys() -- [
'alpha', (opacity of text)
'fontsize', (size in points of text)
'color', (hex color)
'halign', (horizontal alignment, 'left', 'center', or 'right')
'valign', (vertical alignment, 'baseline', 'center', or 'top')
'rotation',
'zorder', (precedence of text when stacked with other objs)
]
"""
self.msg += " Attempting to draw an mpl text object\n"
if not mpltools.check_corners(props["mplobj"], self.mpl_fig):
warnings.warn(
"Looks like the annotation(s) you are trying \n"
"to draw lies/lay outside the given figure size.\n\n"
"Therefore, the resulting Plotly figure may not be \n"
"large enough to view the full text. To adjust \n"
"the size of the figure, use the 'width' and \n"
"'height' keys in the Layout object. Alternatively,\n"
"use the Margin object to adjust the figure's margins."
)
align = props["mplobj"]._multialignment
if not align:
align = props["style"]["halign"] # mpl default
if "annotations" not in self.plotly_fig["layout"]:
self.plotly_fig["layout"]["annotations"] = []
if props["text_type"] == "xlabel":
self.msg += " Text object is an xlabel\n"
self.draw_xlabel(**props)
elif props["text_type"] == "ylabel":
self.msg += " Text object is a ylabel\n"
self.draw_ylabel(**props)
elif props["text_type"] == "title":
self.msg += " Text object is a title\n"
self.draw_title(**props)
else: # just a regular text annotation...
self.msg += " Text object is a normal annotation\n"
if props["coordinates"] != "data":
self.msg += (
" Text object isn't linked to 'data' " "coordinates\n"
)
x_px, y_px = (
props["mplobj"].get_transform().transform(props["position"])
)
x, y = mpltools.display_to_paper(x_px, y_px, self.plotly_fig["layout"])
xref = "paper"
yref = "paper"
xanchor = props["style"]["halign"] # no difference here!
yanchor = mpltools.convert_va(props["style"]["valign"])
else:
self.msg += " Text object is linked to 'data' " "coordinates\n"
x, y = props["position"]
axis_ct = self.axis_ct
xaxis = self.plotly_fig["layout"]["xaxis{0}".format(axis_ct)]
yaxis = self.plotly_fig["layout"]["yaxis{0}".format(axis_ct)]
if (
xaxis["range"][0] < x < xaxis["range"][1]
and yaxis["range"][0] < y < yaxis["range"][1]
):
xref = "x{0}".format(self.axis_ct)
yref = "y{0}".format(self.axis_ct)
else:
self.msg += (
" Text object is outside "
"plotting area, making 'paper' reference.\n"
)
x_px, y_px = (
props["mplobj"].get_transform().transform(props["position"])
)
x, y = mpltools.display_to_paper(
x_px, y_px, self.plotly_fig["layout"]
)
xref = "paper"
yref = "paper"
xanchor = props["style"]["halign"] # no difference here!
yanchor = mpltools.convert_va(props["style"]["valign"])
annotation = go.layout.Annotation(
text=(
str(props["text"])
if isinstance(props["text"], str)
else props["text"]
),
opacity=props["style"]["alpha"],
x=x,
y=y,
xref=xref,
yref=yref,
align=align,
xanchor=xanchor,
yanchor=yanchor,
showarrow=False, # change this later?
font=go.layout.annotation.Font(
color=props["style"]["color"], size=props["style"]["fontsize"]
),
)
self.plotly_fig["layout"]["annotations"] += (annotation,)
self.msg += " Heck, yeah I drew that annotation\n"
def draw_title(self, **props):
"""Add a title to the current subplot in layout dictionary.
If there exists more than a single plot in the figure, titles revert
to 'page'-referenced annotations.
props.keys() -- [
'text', (actual content string, not the text obj)
'position', (an x, y pair, not an mpl Bbox)
'coordinates', ('data', 'axes', 'figure', 'display')
'text_type', ('title', 'xlabel', or 'ylabel')
'style', (style dict, see below)
'mplobj' (actual mpl text object)
]
props['style'].keys() -- [
'alpha', (opacity of text)
'fontsize', (size in points of text)
'color', (hex color)
'halign', (horizontal alignment, 'left', 'center', or 'right')
'valign', (vertical alignment, 'baseline', 'center', or 'top')
'rotation',
'zorder', (precedence of text when stacked with other objs)
]
"""
self.msg += " Attempting to draw a title\n"
if len(self.mpl_fig.axes) > 1:
self.msg += (
" More than one subplot, adding title as " "annotation\n"
)
x_px, y_px = props["mplobj"].get_transform().transform(props["position"])
x, y = mpltools.display_to_paper(x_px, y_px, self.plotly_fig["layout"])
annotation = go.layout.Annotation(
text=props["text"],
font=go.layout.annotation.Font(
color=props["style"]["color"], size=props["style"]["fontsize"]
),
xref="paper",
yref="paper",
x=x,
y=y,
xanchor="center",
yanchor="bottom",
showarrow=False, # no arrow for a title!
)
self.plotly_fig["layout"]["annotations"] += (annotation,)
else:
self.msg += (
" Only one subplot found, adding as a " "plotly title\n"
)
self.plotly_fig["layout"]["title"] = props["text"]
titlefont = dict(
size=props["style"]["fontsize"], color=props["style"]["color"]
)
self.plotly_fig["layout"]["titlefont"] = titlefont
def draw_xlabel(self, **props):
"""Add an xaxis label to the current subplot in layout dictionary.
props.keys() -- [
'text', (actual content string, not the text obj)
'position', (an x, y pair, not an mpl Bbox)
'coordinates', ('data', 'axes', 'figure', 'display')
'text_type', ('title', 'xlabel', or 'ylabel')
'style', (style dict, see below)
'mplobj' (actual mpl text object)
]
props['style'].keys() -- [
'alpha', (opacity of text)
'fontsize', (size in points of text)
'color', (hex color)
'halign', (horizontal alignment, 'left', 'center', or 'right')
'valign', (vertical alignment, 'baseline', 'center', or 'top')
'rotation',
'zorder', (precedence of text when stacked with other objs)
]
"""
self.msg += " Adding xlabel\n"
axis_key = "xaxis{0}".format(self.axis_ct)
self.plotly_fig["layout"][axis_key]["title"] = str(props["text"])
titlefont = dict(size=props["style"]["fontsize"], color=props["style"]["color"])
self.plotly_fig["layout"][axis_key]["titlefont"] = titlefont
def draw_ylabel(self, **props):
"""Add a yaxis label to the current subplot in layout dictionary.
props.keys() -- [
'text', (actual content string, not the text obj)
'position', (an x, y pair, not an mpl Bbox)
'coordinates', ('data', 'axes', 'figure', 'display')
'text_type', ('title', 'xlabel', or 'ylabel')
'style', (style dict, see below)
'mplobj' (actual mpl text object)
]
props['style'].keys() -- [
'alpha', (opacity of text)
'fontsize', (size in points of text)
'color', (hex color)
'halign', (horizontal alignment, 'left', 'center', or 'right')
'valign', (vertical alignment, 'baseline', 'center', or 'top')
'rotation',
'zorder', (precedence of text when stacked with other objs)
]
"""
self.msg += " Adding ylabel\n"
axis_key = "yaxis{0}".format(self.axis_ct)
self.plotly_fig["layout"][axis_key]["title"] = props["text"]
titlefont = dict(size=props["style"]["fontsize"], color=props["style"]["color"])
self.plotly_fig["layout"][axis_key]["titlefont"] = titlefont
def resize(self):
"""Revert figure layout to allow plotly to resize.
By default, PlotlyRenderer tries its hardest to precisely mimic an
mpl figure. However, plotly is pretty good with aesthetics. By
running PlotlyRenderer.resize(), layout parameters are deleted. This
lets plotly choose them instead of mpl.
"""
self.msg += "Resizing figure, deleting keys from layout\n"
for key in ["width", "height", "autosize", "margin"]:
try:
del self.plotly_fig["layout"][key]
except (KeyError, AttributeError):
pass
def strip_style(self):
self.msg += "Stripping mpl style is no longer supported\n"