first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions
+178
View File
@@ -0,0 +1,178 @@
"""
https://plot.ly/python/
Plotly's Python API allows users to programmatically access Plotly's
server resources.
This package is organized as follows:
Subpackages:
- plotly: all functionality that requires access to Plotly's servers
- graph_objs: objects for designing figures and visualizing data
- matplotlylib: tools to convert matplotlib figures
Modules:
- tools: some helpful tools that do not require access to Plotly's servers
- utils: functions that you probably won't need, but that subpackages use
- version: holds the current API version
- exceptions: defines our custom exception classes
"""
from __future__ import absolute_import
import sys
from typing import TYPE_CHECKING
from _plotly_utils.importers import relative_import
if sys.version_info < (3, 7) or TYPE_CHECKING:
from plotly import (
graph_objs,
tools,
utils,
offline,
colors,
io,
data,
)
from plotly.version import __version__
__all__ = [
"graph_objs",
"tools",
"utils",
"offline",
"colors",
"io",
"data",
"__version__",
]
# Set default template (for >= 3.7 this is done in ploty/io/__init__.py)
from plotly.io import templates
templates._default = "plotly"
else:
__all__, __getattr__, __dir__ = relative_import(
__name__,
[
".graph_objs",
".graph_objects",
".tools",
".utils",
".offline",
".colors",
".io",
".data",
],
[".version.__version__"],
)
def plot(data_frame, kind, **kwargs):
"""
Pandas plotting backend function, not meant to be called directly.
To activate, set pandas.options.plotting.backend="plotly"
See https://github.com/pandas-dev/pandas/blob/master/pandas/plotting/__init__.py
"""
from .express import (
scatter,
line,
area,
bar,
box,
histogram,
violin,
strip,
funnel,
density_contour,
density_heatmap,
imshow,
)
if kind == "scatter":
new_kwargs = {k: kwargs[k] for k in kwargs if k not in ["s", "c"]}
return scatter(data_frame, **new_kwargs)
if kind == "line":
return line(data_frame, **kwargs)
if kind == "area":
return area(data_frame, **kwargs)
if kind == "bar":
return bar(data_frame, **kwargs)
if kind == "barh":
return bar(data_frame, orientation="h", **kwargs)
if kind == "box":
new_kwargs = {k: kwargs[k] for k in kwargs if k not in ["by"]}
return box(data_frame, **new_kwargs)
if kind in ["hist", "histogram"]:
new_kwargs = {k: kwargs[k] for k in kwargs if k not in ["by", "bins"]}
return histogram(data_frame, **new_kwargs)
if kind == "violin":
return violin(data_frame, **kwargs)
if kind == "strip":
return strip(data_frame, **kwargs)
if kind == "funnel":
return funnel(data_frame, **kwargs)
if kind == "density_contour":
return density_contour(data_frame, **kwargs)
if kind == "density_heatmap":
return density_heatmap(data_frame, **kwargs)
if kind == "imshow":
return imshow(data_frame, **kwargs)
if kind == "heatmap":
raise ValueError(
"kind='heatmap' not supported plotting.backend='plotly'. "
"Please use kind='imshow' or kind='density_heatmap'."
)
raise NotImplementedError(
"kind='%s' not yet supported for plotting.backend='plotly'" % kind
)
def boxplot_frame(data_frame, **kwargs):
"""
Pandas plotting backend function, not meant to be called directly.
To activate, set pandas.options.plotting.backend="plotly"
See https://github.com/pandas-dev/pandas/blob/master/pandas/plotting/__init__.py
"""
from .express import box
skip = ["by", "column", "ax", "fontsize", "rot", "grid", "figsize", "layout"]
skip += ["return_type"]
new_kwargs = {k: kwargs[k] for k in kwargs if k not in skip}
return box(data_frame, **new_kwargs)
def hist_frame(data_frame, **kwargs):
"""
Pandas plotting backend function, not meant to be called directly.
To activate, set pandas.options.plotting.backend="plotly"
See https://github.com/pandas-dev/pandas/blob/master/pandas/plotting/__init__.py
"""
from .express import histogram
skip = ["column", "by", "grid", "xlabelsize", "xrot", "ylabelsize", "yrot"]
skip += ["ax", "sharex", "sharey", "figsize", "layout", "bins", "legend"]
new_kwargs = {k: kwargs[k] for k in kwargs if k not in skip}
return histogram(data_frame, **new_kwargs)
def hist_series(data_frame, **kwargs):
"""
Pandas plotting backend function, not meant to be called directly.
To activate, set pandas.options.plotting.backend="plotly"
See https://github.com/pandas-dev/pandas/blob/master/pandas/plotting/__init__.py
"""
from .express import histogram
skip = ["by", "grid", "xlabelsize", "xrot", "ylabelsize", "yrot", "ax"]
skip += ["figsize", "bins", "legend"]
new_kwargs = {k: kwargs[k] for k in kwargs if k not in skip}
return histogram(data_frame, **new_kwargs)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,21 @@
# This file was generated by 'versioneer.py' (0.21) from
# revision-control system data, or from the parent directory name of an
# unpacked source archive. Distribution tarballs contain a pre-generated copy
# of this file.
import json
version_json = '''
{
"date": "2022-05-09T20:40:15-0400",
"dirty": false,
"error": null,
"full-revisionid": "eca5fe62f9262478bacc5895db3a053116cac393",
"version": "5.8.0"
}
''' # END VERSION_JSON
def get_versions():
return json.loads(version_json)
@@ -0,0 +1,5 @@
# This file is generated by the updateplotlywidgetversion setup.py command
# for automated dev builds
#
# It is edited by hand prior to official releases
__frontend_version__ = "^5.8.0"
@@ -0,0 +1,54 @@
from _plotly_utils.basevalidators import EnumeratedValidator, NumberValidator
class EasingValidator(EnumeratedValidator):
def __init__(self, plotly_name="easing", parent_name="batch_animate", **_):
super(EasingValidator, self).__init__(
plotly_name=plotly_name,
parent_name=parent_name,
values=[
"linear",
"quad",
"cubic",
"sin",
"exp",
"circle",
"elastic",
"back",
"bounce",
"linear-in",
"quad-in",
"cubic-in",
"sin-in",
"exp-in",
"circle-in",
"elastic-in",
"back-in",
"bounce-in",
"linear-out",
"quad-out",
"cubic-out",
"sin-out",
"exp-out",
"circle-out",
"elastic-out",
"back-out",
"bounce-out",
"linear-in-out",
"quad-in-out",
"cubic-in-out",
"sin-in-out",
"exp-in-out",
"circle-in-out",
"elastic-in-out",
"back-in-out",
"bounce-in-out",
],
)
class DurationValidator(NumberValidator):
def __init__(self, plotly_name="duration"):
super(DurationValidator, self).__init__(
plotly_name=plotly_name, parent_name="batch_animate", min=0
)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,988 @@
import uuid
from importlib import import_module
import os
import numbers
try:
from urllib import parse
except ImportError:
from urlparse import urlparse as parse
import ipywidgets as widgets
from traitlets import List, Unicode, Dict, observe, Integer
from .basedatatypes import BaseFigure, BasePlotlyType
from .callbacks import BoxSelector, LassoSelector, InputDeviceState, Points
from .serializers import custom_serializers
from .version import __frontend_version__
@widgets.register
class BaseFigureWidget(BaseFigure, widgets.DOMWidget):
"""
Base class for FigureWidget. The FigureWidget class is code-generated as a
subclass
"""
# Widget Traits
# -------------
# Widget traitlets are automatically synchronized with the FigureModel
# JavaScript object
_view_name = Unicode("FigureView").tag(sync=True)
_view_module = Unicode("jupyterlab-plotly").tag(sync=True)
_view_module_version = Unicode(__frontend_version__).tag(sync=True)
_model_name = Unicode("FigureModel").tag(sync=True)
_model_module = Unicode("jupyterlab-plotly").tag(sync=True)
_model_module_version = Unicode(__frontend_version__).tag(sync=True)
# ### _data and _layout ###
# These properties store the current state of the traces and
# layout as JSON-style dicts. These dicts do not store any subclasses of
# `BasePlotlyType`
#
# Note: These are only automatically synced with the frontend on full
# assignment, not on mutation. We use this fact to only directly sync
# them to the front-end on FigureWidget construction. All other updates
# are made using mutation, and they are manually synced to the frontend
# using the relayout/restyle/update/etc. messages.
_layout = Dict().tag(sync=True, **custom_serializers)
_data = List().tag(sync=True, **custom_serializers)
_config = Dict().tag(sync=True, **custom_serializers)
# ### Python -> JS message properties ###
# These properties are used to send messages from Python to the
# frontend. Messages are sent by assigning the message contents to the
# appropriate _py2js_* property and then immediatly assigning None to the
# property.
#
# See JSDoc comments in the FigureModel class in js/src/Figure.js for
# detailed descriptions of the messages.
_py2js_addTraces = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_py2js_restyle = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_py2js_relayout = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_py2js_update = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_py2js_animate = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_py2js_deleteTraces = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_py2js_moveTraces = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_py2js_removeLayoutProps = Dict(allow_none=True).tag(
sync=True, **custom_serializers
)
_py2js_removeTraceProps = Dict(allow_none=True).tag(sync=True, **custom_serializers)
# ### JS -> Python message properties ###
# These properties are used to receive messages from the frontend.
# Messages are received by defining methods that observe changes to these
# properties. Receive methods are named `_handler_js2py_*` where '*' is
# the name of the corresponding message property. Receive methods are
# responsible for setting the message property to None after retreiving
# the message data.
#
# See JSDoc comments in the FigureModel class in js/src/Figure.js for
# detailed descriptions of the messages.
_js2py_traceDeltas = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_js2py_layoutDelta = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_js2py_restyle = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_js2py_relayout = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_js2py_update = Dict(allow_none=True).tag(sync=True, **custom_serializers)
_js2py_pointsCallback = Dict(allow_none=True).tag(sync=True, **custom_serializers)
# ### Message tracking properties ###
# The _last_layout_edit_id and _last_trace_edit_id properties are used
# to keep track of the edit id of the message that most recently
# requested an update to the Figures layout or traces respectively.
#
# We track this information because we don't want to update the Figure's
# default layout/trace properties (_layout_defaults, _data_defaults)
# while edits are in process. This can lead to inconsistent property
# states.
_last_layout_edit_id = Integer(0).tag(sync=True)
_last_trace_edit_id = Integer(0).tag(sync=True)
_set_trace_uid = True
_allow_disable_validation = False
# Constructor
# -----------
def __init__(
self, data=None, layout=None, frames=None, skip_invalid=False, **kwargs
):
# Call superclass constructors
# ----------------------------
# Note: We rename layout to layout_plotly because to deconflict it
# with the `layout` constructor parameter of the `widgets.DOMWidget`
# ipywidgets class
super(BaseFigureWidget, self).__init__(
data=data,
layout_plotly=layout,
frames=frames,
skip_invalid=skip_invalid,
**kwargs,
)
# Validate Frames
# ---------------
# Frames are not supported by figure widget
if self._frame_objs:
BaseFigureWidget._display_frames_error()
# Message States
# --------------
# ### Layout ###
# _last_layout_edit_id is described above
self._last_layout_edit_id = 0
# _layout_edit_in_process is set to True if there are layout edit
# operations that have been sent to the frontend that haven't
# completed yet.
self._layout_edit_in_process = False
# _waiting_edit_callbacks is a list of callback functions that
# should be executed as soon as all pending edit operations are
# completed
self._waiting_edit_callbacks = []
# ### Trace ###
# _last_trace_edit_id: described above
self._last_trace_edit_id = 0
# _trace_edit_in_process is set to True if there are trace edit
# operations that have been sent to the frontend that haven't
# completed yet.
self._trace_edit_in_process = False
# View count
# ----------
# ipywidget property that stores the number of active frontend
# views of this widget
self._view_count = 0
# Python -> JavaScript Messages
# -----------------------------
def _send_relayout_msg(self, layout_data, source_view_id=None):
"""
Send Plotly.relayout message to the frontend
Parameters
----------
layout_data : dict
Plotly.relayout layout data
source_view_id : str
UID of view that triggered this relayout operation
(e.g. By the user clicking 'zoom' in the toolbar). None if the
operation was not triggered by a frontend view
"""
# Increment layout edit messages IDs
# ----------------------------------
layout_edit_id = self._last_layout_edit_id + 1
self._last_layout_edit_id = layout_edit_id
self._layout_edit_in_process = True
# Build message
# -------------
msg_data = {
"relayout_data": layout_data,
"layout_edit_id": layout_edit_id,
"source_view_id": source_view_id,
}
# Send message
# ------------
self._py2js_relayout = msg_data
self._py2js_relayout = None
def _send_restyle_msg(self, restyle_data, trace_indexes=None, source_view_id=None):
"""
Send Plotly.restyle message to the frontend
Parameters
----------
restyle_data : dict
Plotly.restyle restyle data
trace_indexes : list[int]
List of trace indexes that the restyle operation
applies to
source_view_id : str
UID of view that triggered this restyle operation
(e.g. By the user clicking the legend to hide a trace).
None if the operation was not triggered by a frontend view
"""
# Validate / normalize inputs
# ---------------------------
trace_indexes = self._normalize_trace_indexes(trace_indexes)
# Increment layout/trace edit message IDs
# ---------------------------------------
layout_edit_id = self._last_layout_edit_id + 1
self._last_layout_edit_id = layout_edit_id
self._layout_edit_in_process = True
trace_edit_id = self._last_trace_edit_id + 1
self._last_trace_edit_id = trace_edit_id
self._trace_edit_in_process = True
# Build message
# -------------
restyle_msg = {
"restyle_data": restyle_data,
"restyle_traces": trace_indexes,
"trace_edit_id": trace_edit_id,
"layout_edit_id": layout_edit_id,
"source_view_id": source_view_id,
}
# Send message
# ------------
self._py2js_restyle = restyle_msg
self._py2js_restyle = None
def _send_addTraces_msg(self, new_traces_data):
"""
Send Plotly.addTraces message to the frontend
Parameters
----------
new_traces_data : list[dict]
List of trace data for new traces as accepted by Plotly.addTraces
"""
# Increment layout/trace edit message IDs
# ---------------------------------------
layout_edit_id = self._last_layout_edit_id + 1
self._last_layout_edit_id = layout_edit_id
self._layout_edit_in_process = True
trace_edit_id = self._last_trace_edit_id + 1
self._last_trace_edit_id = trace_edit_id
self._trace_edit_in_process = True
# Build message
# -------------
add_traces_msg = {
"trace_data": new_traces_data,
"trace_edit_id": trace_edit_id,
"layout_edit_id": layout_edit_id,
}
# Send message
# ------------
self._py2js_addTraces = add_traces_msg
self._py2js_addTraces = None
def _send_moveTraces_msg(self, current_inds, new_inds):
"""
Send Plotly.moveTraces message to the frontend
Parameters
----------
current_inds : list[int]
List of current trace indexes
new_inds : list[int]
List of new trace indexes
"""
# Build message
# -------------
move_msg = {"current_trace_inds": current_inds, "new_trace_inds": new_inds}
# Send message
# ------------
self._py2js_moveTraces = move_msg
self._py2js_moveTraces = None
def _send_update_msg(
self, restyle_data, relayout_data, trace_indexes=None, source_view_id=None
):
"""
Send Plotly.update message to the frontend
Parameters
----------
restyle_data : dict
Plotly.update restyle data
relayout_data : dict
Plotly.update relayout data
trace_indexes : list[int]
List of trace indexes that the update operation applies to
source_view_id : str
UID of view that triggered this update operation
(e.g. By the user clicking a button).
None if the operation was not triggered by a frontend view
"""
# Validate / normalize inputs
# ---------------------------
trace_indexes = self._normalize_trace_indexes(trace_indexes)
# Increment layout/trace edit message IDs
# ---------------------------------------
trace_edit_id = self._last_trace_edit_id + 1
self._last_trace_edit_id = trace_edit_id
self._trace_edit_in_process = True
layout_edit_id = self._last_layout_edit_id + 1
self._last_layout_edit_id = layout_edit_id
self._layout_edit_in_process = True
# Build message
# -------------
update_msg = {
"style_data": restyle_data,
"layout_data": relayout_data,
"style_traces": trace_indexes,
"trace_edit_id": trace_edit_id,
"layout_edit_id": layout_edit_id,
"source_view_id": source_view_id,
}
# Send message
# ------------
self._py2js_update = update_msg
self._py2js_update = None
def _send_animate_msg(
self, styles_data, relayout_data, trace_indexes, animation_opts
):
"""
Send Plotly.update message to the frontend
Note: there is no source_view_id parameter because animations
triggered by the fontend are not currently supported
Parameters
----------
styles_data : list[dict]
Plotly.animate styles data
relayout_data : dict
Plotly.animate relayout data
trace_indexes : list[int]
List of trace indexes that the animate operation applies to
"""
# Validate / normalize inputs
# ---------------------------
trace_indexes = self._normalize_trace_indexes(trace_indexes)
# Increment layout/trace edit message IDs
# ---------------------------------------
trace_edit_id = self._last_trace_edit_id + 1
self._last_trace_edit_id = trace_edit_id
self._trace_edit_in_process = True
layout_edit_id = self._last_layout_edit_id + 1
self._last_layout_edit_id = layout_edit_id
self._layout_edit_in_process = True
# Build message
# -------------
animate_msg = {
"style_data": styles_data,
"layout_data": relayout_data,
"style_traces": trace_indexes,
"animation_opts": animation_opts,
"trace_edit_id": trace_edit_id,
"layout_edit_id": layout_edit_id,
"source_view_id": None,
}
# Send message
# ------------
self._py2js_animate = animate_msg
self._py2js_animate = None
def _send_deleteTraces_msg(self, delete_inds):
"""
Send Plotly.deleteTraces message to the frontend
Parameters
----------
delete_inds : list[int]
List of trace indexes of traces to delete
"""
# Increment layout/trace edit message IDs
# ---------------------------------------
trace_edit_id = self._last_trace_edit_id + 1
self._last_trace_edit_id = trace_edit_id
self._trace_edit_in_process = True
layout_edit_id = self._last_layout_edit_id + 1
self._last_layout_edit_id = layout_edit_id
self._layout_edit_in_process = True
# Build message
# -------------
delete_msg = {
"delete_inds": delete_inds,
"layout_edit_id": layout_edit_id,
"trace_edit_id": trace_edit_id,
}
# Send message
# ------------
self._py2js_deleteTraces = delete_msg
self._py2js_deleteTraces = None
# JavaScript -> Python Messages
# -----------------------------
@observe("_js2py_traceDeltas")
def _handler_js2py_traceDeltas(self, change):
"""
Process trace deltas message from the frontend
"""
# Receive message
# ---------------
msg_data = change["new"]
if not msg_data:
self._js2py_traceDeltas = None
return
trace_deltas = msg_data["trace_deltas"]
trace_edit_id = msg_data["trace_edit_id"]
# Apply deltas
# ------------
# We only apply the deltas if this message corresponds to the most
# recent trace edit operation
if trace_edit_id == self._last_trace_edit_id:
# ### Loop over deltas ###
for delta in trace_deltas:
# #### Find existing trace for uid ###
trace_uid = delta["uid"]
trace_uids = [trace.uid for trace in self.data]
trace_index = trace_uids.index(trace_uid)
uid_trace = self.data[trace_index]
# #### Transform defaults to delta ####
delta_transform = BaseFigureWidget._transform_data(
uid_trace._prop_defaults, delta
)
# #### Remove overlapping properties ####
# If a property is present in both _props and _prop_defaults
# then we remove the copy from _props
remove_props = self._remove_overlapping_props(
uid_trace._props, uid_trace._prop_defaults
)
# #### Notify frontend model of property removal ####
if remove_props:
remove_trace_props_msg = {
"remove_trace": trace_index,
"remove_props": remove_props,
}
self._py2js_removeTraceProps = remove_trace_props_msg
self._py2js_removeTraceProps = None
# #### Dispatch change callbacks ####
self._dispatch_trace_change_callbacks(delta_transform, [trace_index])
# ### Trace edits no longer in process ###
self._trace_edit_in_process = False
# ### Call any waiting trace edit callbacks ###
if not self._layout_edit_in_process:
while self._waiting_edit_callbacks:
self._waiting_edit_callbacks.pop()()
self._js2py_traceDeltas = None
@observe("_js2py_layoutDelta")
def _handler_js2py_layoutDelta(self, change):
"""
Process layout delta message from the frontend
"""
# Receive message
# ---------------
msg_data = change["new"]
if not msg_data:
self._js2py_layoutDelta = None
return
layout_delta = msg_data["layout_delta"]
layout_edit_id = msg_data["layout_edit_id"]
# Apply delta
# -----------
# We only apply the delta if this message corresponds to the most
# recent layout edit operation
if layout_edit_id == self._last_layout_edit_id:
# ### Transform defaults to delta ###
delta_transform = BaseFigureWidget._transform_data(
self._layout_defaults, layout_delta
)
# ### Remove overlapping properties ###
# If a property is present in both _layout and _layout_defaults
# then we remove the copy from _layout
removed_props = self._remove_overlapping_props(
self._layout, self._layout_defaults
)
# ### Notify frontend model of property removal ###
if removed_props:
remove_props_msg = {"remove_props": removed_props}
self._py2js_removeLayoutProps = remove_props_msg
self._py2js_removeLayoutProps = None
# ### Create axis objects ###
# For example, when a SPLOM trace is created the layout defaults
# may include axes that weren't explicitly defined by the user.
for proppath in delta_transform:
prop = proppath[0]
match = self.layout._subplot_re_match(prop)
if match and prop not in self.layout:
# We need to create a subplotid object
self.layout[prop] = {}
# ### Dispatch change callbacks ###
self._dispatch_layout_change_callbacks(delta_transform)
# ### Layout edits no longer in process ###
self._layout_edit_in_process = False
# ### Call any waiting layout edit callbacks ###
if not self._trace_edit_in_process:
while self._waiting_edit_callbacks:
self._waiting_edit_callbacks.pop()()
self._js2py_layoutDelta = None
@observe("_js2py_restyle")
def _handler_js2py_restyle(self, change):
"""
Process Plotly.restyle message from the frontend
"""
# Receive message
# ---------------
restyle_msg = change["new"]
if not restyle_msg:
self._js2py_restyle = None
return
style_data = restyle_msg["style_data"]
style_traces = restyle_msg["style_traces"]
source_view_id = restyle_msg["source_view_id"]
# Perform restyle
# ---------------
self.plotly_restyle(
restyle_data=style_data,
trace_indexes=style_traces,
source_view_id=source_view_id,
)
self._js2py_restyle = None
@observe("_js2py_update")
def _handler_js2py_update(self, change):
"""
Process Plotly.update message from the frontend
"""
# Receive message
# ---------------
update_msg = change["new"]
if not update_msg:
self._js2py_update = None
return
style = update_msg["style_data"]
trace_indexes = update_msg["style_traces"]
layout = update_msg["layout_data"]
source_view_id = update_msg["source_view_id"]
# Perform update
# --------------
self.plotly_update(
restyle_data=style,
relayout_data=layout,
trace_indexes=trace_indexes,
source_view_id=source_view_id,
)
self._js2py_update = None
@observe("_js2py_relayout")
def _handler_js2py_relayout(self, change):
"""
Process Plotly.relayout message from the frontend
"""
# Receive message
# ---------------
relayout_msg = change["new"]
if not relayout_msg:
self._js2py_relayout = None
return
relayout_data = relayout_msg["relayout_data"]
source_view_id = relayout_msg["source_view_id"]
if "lastInputTime" in relayout_data:
# Remove 'lastInputTime'. Seems to be an internal plotly
# property that is introduced for some plot types, but it is not
# actually a property in the schema
relayout_data.pop("lastInputTime")
# Perform relayout
# ----------------
self.plotly_relayout(relayout_data=relayout_data, source_view_id=source_view_id)
self._js2py_relayout = None
@observe("_js2py_pointsCallback")
def _handler_js2py_pointsCallback(self, change):
"""
Process points callback message from the frontend
"""
# Receive message
# ---------------
callback_data = change["new"]
if not callback_data:
self._js2py_pointsCallback = None
return
# Get event type
# --------------
event_type = callback_data["event_type"]
# Build Selector Object
# ---------------------
if callback_data.get("selector", None):
selector_data = callback_data["selector"]
selector_type = selector_data["type"]
selector_state = selector_data["selector_state"]
if selector_type == "box":
selector = BoxSelector(**selector_state)
elif selector_type == "lasso":
selector = LassoSelector(**selector_state)
else:
raise ValueError("Unsupported selector type: %s" % selector_type)
else:
selector = None
# Build Input Device State Object
# -------------------------------
if callback_data.get("device_state", None):
device_state_data = callback_data["device_state"]
state = InputDeviceState(**device_state_data)
else:
state = None
# Build Trace Points Dictionary
# -----------------------------
points_data = callback_data["points"]
trace_points = {
trace_ind: {
"point_inds": [],
"xs": [],
"ys": [],
"trace_name": self._data_objs[trace_ind].name,
"trace_index": trace_ind,
}
for trace_ind in range(len(self._data_objs))
}
for x, y, point_ind, trace_ind in zip(
points_data["xs"],
points_data["ys"],
points_data["point_indexes"],
points_data["trace_indexes"],
):
trace_dict = trace_points[trace_ind]
trace_dict["xs"].append(x)
trace_dict["ys"].append(y)
trace_dict["point_inds"].append(point_ind)
# Dispatch callbacks
# ------------------
for trace_ind, trace_points_data in trace_points.items():
points = Points(**trace_points_data)
trace = self.data[trace_ind]
if event_type == "plotly_click":
trace._dispatch_on_click(points, state)
elif event_type == "plotly_hover":
trace._dispatch_on_hover(points, state)
elif event_type == "plotly_unhover":
trace._dispatch_on_unhover(points, state)
elif event_type == "plotly_selected":
trace._dispatch_on_selection(points, selector)
elif event_type == "plotly_deselect":
trace._dispatch_on_deselect(points)
self._js2py_pointsCallback = None
# Display
# -------
def _ipython_display_(self):
"""
Handle rich display of figures in ipython contexts
"""
# Override BaseFigure's display to make sure we display the widget version
widgets.DOMWidget._ipython_display_(self)
# Callbacks
# ---------
def on_edits_completed(self, fn):
"""
Register a function to be called after all pending trace and layout
edit operations have completed
If there are no pending edit operations then function is called
immediately
Parameters
----------
fn : callable
Function of zero arguments to be called when all pending edit
operations have completed
"""
if self._layout_edit_in_process or self._trace_edit_in_process:
self._waiting_edit_callbacks.append(fn)
else:
fn()
# Validate No Frames
# ------------------
@property
def frames(self):
# Note: This property getter is identical to that of the superclass,
# but it must be included here because we're overriding the setter
# below.
return self._frame_objs
@frames.setter
def frames(self, new_frames):
if new_frames:
BaseFigureWidget._display_frames_error()
@staticmethod
def _display_frames_error():
"""
Display an informative error when user attempts to set frames on a
FigureWidget
Raises
------
ValueError
always
"""
msg = """
Frames are not supported by the plotly.graph_objs.FigureWidget class.
Note: Frames are supported by the plotly.graph_objs.Figure class"""
raise ValueError(msg)
# Static Helpers
# --------------
@staticmethod
def _remove_overlapping_props(input_data, delta_data, prop_path=()):
"""
Remove properties in input_data that are also in delta_data, and do so
recursively.
Exception: Never remove 'uid' from input_data, this property is used
to align traces
Parameters
----------
input_data : dict|list
delta_data : dict|list
Returns
-------
list[tuple[str|int]]
List of removed property path tuples
"""
# Initialize removed
# ------------------
# This is the list of path tuples to the properties that were
# removed from input_data
removed = []
# Handle dict
# -----------
if isinstance(input_data, dict):
assert isinstance(delta_data, dict)
for p, delta_val in delta_data.items():
if isinstance(delta_val, dict) or BaseFigure._is_dict_list(delta_val):
if p in input_data:
# ### Recurse ###
input_val = input_data[p]
recur_prop_path = prop_path + (p,)
recur_removed = BaseFigureWidget._remove_overlapping_props(
input_val, delta_val, recur_prop_path
)
removed.extend(recur_removed)
# Check whether the last property in input_val
# has been removed. If so, remove it entirely
if not input_val:
input_data.pop(p)
removed.append(recur_prop_path)
elif p in input_data and p != "uid":
# ### Remove property ###
input_data.pop(p)
removed.append(prop_path + (p,))
# Handle list
# -----------
elif isinstance(input_data, list):
assert isinstance(delta_data, list)
for i, delta_val in enumerate(delta_data):
if i >= len(input_data):
break
input_val = input_data[i]
if (
input_val is not None
and isinstance(delta_val, dict)
or BaseFigure._is_dict_list(delta_val)
):
# ### Recurse ###
recur_prop_path = prop_path + (i,)
recur_removed = BaseFigureWidget._remove_overlapping_props(
input_val, delta_val, recur_prop_path
)
removed.extend(recur_removed)
return removed
@staticmethod
def _transform_data(to_data, from_data, should_remove=True, relayout_path=()):
"""
Transform to_data into from_data and return relayout-style
description of the transformation
Parameters
----------
to_data : dict|list
from_data : dict|list
Returns
-------
dict
relayout-style description of the transformation
"""
# Initialize relayout data
# ------------------------
relayout_data = {}
# Handle dict
# -----------
if isinstance(to_data, dict):
# ### Validate from_data ###
if not isinstance(from_data, dict):
raise ValueError(
"Mismatched data types: {to_dict} {from_data}".format(
to_dict=to_data, from_data=from_data
)
)
# ### Add/modify properties ###
# Loop over props/vals
for from_prop, from_val in from_data.items():
# #### Handle compound vals recursively ####
if isinstance(from_val, dict) or BaseFigure._is_dict_list(from_val):
# ##### Init property value if needed #####
if from_prop not in to_data:
to_data[from_prop] = {} if isinstance(from_val, dict) else []
# ##### Transform property val recursively #####
input_val = to_data[from_prop]
relayout_data.update(
BaseFigureWidget._transform_data(
input_val,
from_val,
should_remove=should_remove,
relayout_path=relayout_path + (from_prop,),
)
)
# #### Handle simple vals directly ####
else:
if from_prop not in to_data or not BasePlotlyType._vals_equal(
to_data[from_prop], from_val
):
to_data[from_prop] = from_val
relayout_path_prop = relayout_path + (from_prop,)
relayout_data[relayout_path_prop] = from_val
# ### Remove properties ###
if should_remove:
for remove_prop in set(to_data.keys()).difference(
set(from_data.keys())
):
to_data.pop(remove_prop)
# Handle list
# -----------
elif isinstance(to_data, list):
# ### Validate from_data ###
if not isinstance(from_data, list):
raise ValueError(
"Mismatched data types: to_data: {to_data} {from_data}".format(
to_data=to_data, from_data=from_data
)
)
# ### Add/modify properties ###
# Loop over indexes / elements
for i, from_val in enumerate(from_data):
# #### Initialize element if needed ####
if i >= len(to_data):
to_data.append(None)
input_val = to_data[i]
# #### Handle compound element recursively ####
if input_val is not None and (
isinstance(from_val, dict) or BaseFigure._is_dict_list(from_val)
):
relayout_data.update(
BaseFigureWidget._transform_data(
input_val,
from_val,
should_remove=should_remove,
relayout_path=relayout_path + (i,),
)
)
# #### Handle simple elements directly ####
else:
if not BasePlotlyType._vals_equal(to_data[i], from_val):
to_data[i] = from_val
relayout_data[relayout_path + (i,)] = from_val
return relayout_data
+299
View File
@@ -0,0 +1,299 @@
from __future__ import absolute_import
from plotly.utils import _list_repr_elided
class InputDeviceState:
def __init__(
self, ctrl=None, alt=None, shift=None, meta=None, button=None, buttons=None, **_
):
self._ctrl = ctrl
self._alt = alt
self._meta = meta
self._shift = shift
self._button = button
self._buttons = buttons
def __repr__(self):
return """\
InputDeviceState(
ctrl={ctrl},
alt={alt},
shift={shift},
meta={meta},
button={button},
buttons={buttons})""".format(
ctrl=repr(self.ctrl),
alt=repr(self.alt),
meta=repr(self.meta),
shift=repr(self.shift),
button=repr(self.button),
buttons=repr(self.buttons),
)
@property
def alt(self):
"""
Whether alt key pressed
Returns
-------
bool
"""
return self._alt
@property
def ctrl(self):
"""
Whether ctrl key pressed
Returns
-------
bool
"""
return self._ctrl
@property
def shift(self):
"""
Whether shift key pressed
Returns
-------
bool
"""
return self._shift
@property
def meta(self):
"""
Whether meta key pressed
Returns
-------
bool
"""
return self._meta
@property
def button(self):
"""
Integer code for the button that was pressed on the mouse to trigger
the event
- 0: Main button pressed, usually the left button or the
un-initialized state
- 1: Auxiliary button pressed, usually the wheel button or the middle
button (if present)
- 2: Secondary button pressed, usually the right button
- 3: Fourth button, typically the Browser Back button
- 4: Fifth button, typically the Browser Forward button
Returns
-------
int
"""
return self._button
@property
def buttons(self):
"""
Integer code for which combination of buttons are pressed on the
mouse when the event is triggered.
- 0: No button or un-initialized
- 1: Primary button (usually left)
- 2: Secondary button (usually right)
- 4: Auxilary button (usually middle or mouse wheel button)
- 8: 4th button (typically the "Browser Back" button)
- 16: 5th button (typically the "Browser Forward" button)
Combinations of buttons are represented as the decimal form of the
bitmask of the values above.
For example, pressing both the primary (1) and auxilary (4) buttons
will result in a code of 5
Returns
-------
int
"""
return self._buttons
class Points:
def __init__(self, point_inds=[], xs=[], ys=[], trace_name=None, trace_index=None):
self._point_inds = point_inds
self._xs = xs
self._ys = ys
self._trace_name = trace_name
self._trace_index = trace_index
def __repr__(self):
return """\
Points(point_inds={point_inds},
xs={xs},
ys={ys},
trace_name={trace_name},
trace_index={trace_index})""".format(
point_inds=_list_repr_elided(
self.point_inds, indent=len("Points(point_inds=")
),
xs=_list_repr_elided(self.xs, indent=len(" xs=")),
ys=_list_repr_elided(self.ys, indent=len(" ys=")),
trace_name=repr(self.trace_name),
trace_index=repr(self.trace_index),
)
@property
def point_inds(self):
"""
List of selected indexes into the trace's points
Returns
-------
list[int]
"""
return self._point_inds
@property
def xs(self):
"""
List of x-coordinates of selected points
Returns
-------
list[float]
"""
return self._xs
@property
def ys(self):
"""
List of y-coordinates of selected points
Returns
-------
list[float]
"""
return self._ys
@property
def trace_name(self):
"""
Name of the trace
Returns
-------
str
"""
return self._trace_name
@property
def trace_index(self):
"""
Index of the trace in the figure
Returns
-------
int
"""
return self._trace_index
class BoxSelector:
def __init__(self, xrange=None, yrange=None, **_):
self._type = "box"
self._xrange = xrange
self._yrange = yrange
def __repr__(self):
return """\
BoxSelector(xrange={xrange},
yrange={yrange})""".format(
xrange=self.xrange, yrange=self.yrange
)
@property
def type(self):
"""
The selector's type
Returns
-------
str
"""
return self._type
@property
def xrange(self):
"""
x-axis range extents of the box selection
Returns
-------
(float, float)
"""
return self._xrange
@property
def yrange(self):
"""
y-axis range extents of the box selection
Returns
-------
(float, float)
"""
return self._yrange
class LassoSelector:
def __init__(self, xs=None, ys=None, **_):
self._type = "lasso"
self._xs = xs
self._ys = ys
def __repr__(self):
return """\
LassoSelector(xs={xs},
ys={ys})""".format(
xs=_list_repr_elided(self.xs, indent=len("LassoSelector(xs=")),
ys=_list_repr_elided(self.ys, indent=len(" ys=")),
)
@property
def type(self):
"""
The selector's type
Returns
-------
str
"""
return self._type
@property
def xs(self):
"""
list of x-axis coordinates of each point in the lasso selection
boundary
Returns
-------
list[float]
"""
return self._xs
@property
def ys(self):
"""
list of y-axis coordinates of each point in the lasso selection
boundary
Returns
-------
list[float]
"""
return self._ys
@@ -0,0 +1,50 @@
"""For a list of colors available in `plotly.colors`, please see
* the `tutorial on discrete color sequences <https://plotly.com/python/discrete-color/#color-sequences-in-plotly-express>`_
* the `list of built-in continuous color scales <https://plotly.com/python/builtin-colorscales/>`_
* the `tutorial on continuous colors <https://plotly.com/python/colorscales/>`_
Color scales and sequences are available within the following namespaces
* cyclical
* diverging
* qualitative
* sequential
"""
from __future__ import absolute_import
from _plotly_utils.colors import * # noqa: F401
__all__ = [
"named_colorscales",
"cyclical",
"diverging",
"sequential",
"qualitative",
"colorbrewer",
"carto",
"cmocean",
"color_parser",
"colorscale_to_colors",
"colorscale_to_scale",
"convert_colors_to_same_type",
"convert_colorscale_to_rgb",
"convert_dict_colors_to_same_type",
"convert_to_RGB_255",
"find_intermediate_color",
"hex_to_rgb",
"label_rgb",
"make_colorscale",
"n_colors",
"sample_colorscale",
"unconvert_from_RGB_255",
"unlabel_rgb",
"validate_colors",
"validate_colors_dict",
"validate_colorscale",
"validate_scale_values",
"plotlyjs",
"DEFAULT_PLOTLY_COLORS",
"PLOTLY_SCALES",
"get_colorscale",
]
+4
View File
@@ -0,0 +1,4 @@
from __future__ import absolute_import
from _plotly_future_ import _chart_studio_error
_chart_studio_error("config")
@@ -0,0 +1,25 @@
import pytest
import os
def pytest_ignore_collect(path):
# Ignored files, most of them are raising a chart studio error
ignored_paths = [
"exploding_module.py",
"chunked_requests.py",
"v2.py",
"v1.py",
"presentation_objs.py",
"widgets.py",
"dashboard_objs.py",
"grid_objs.py",
"config.py",
"presentation_objs.py",
"session.py",
]
if (
os.path.basename(str(path)) in ignored_paths
or "plotly/plotly/plotly/__init__.py" in str(path)
or "plotly/api/utils.py" in str(path)
):
return True
@@ -0,0 +1,4 @@
from __future__ import absolute_import
from _plotly_future_ import _chart_studio_error
_chart_studio_error("dashboard_objs")
@@ -0,0 +1,222 @@
"""
Built-in datasets for demonstration, educational and test purposes.
"""
def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False):
"""
Each row represents a country on a given year.
https://www.gapminder.org/data/
Returns:
A `pandas.DataFrame` with 1704 rows and the following columns:
`['country', 'continent', 'year', 'lifeExp', 'pop', 'gdpPercap',
'iso_alpha', 'iso_num']`.
If `datetimes` is True, the 'year' column will be a datetime column
If `centroids` is True, two new columns are added: ['centroid_lat', 'centroid_lon']
If `year` is an integer, the dataset will be filtered for that year
"""
df = _get_dataset("gapminder")
if year:
df = df[df["year"] == year]
if datetimes:
df["year"] = (df["year"].astype(str) + "-01-01").astype("datetime64[ns]")
if not centroids:
df = df.drop(["centroid_lat", "centroid_lon"], axis=1)
if pretty_names:
df.rename(
mapper=dict(
country="Country",
continent="Continent",
year="Year",
lifeExp="Life Expectancy",
gdpPercap="GDP per Capita",
pop="Population",
iso_alpha="ISO Alpha Country Code",
iso_num="ISO Numeric Country Code",
centroid_lat="Centroid Latitude",
centroid_lon="Centroid Longitude",
),
axis="columns",
inplace=True,
)
return df
def tips(pretty_names=False):
"""
Each row represents a restaurant bill.
https://vincentarelbundock.github.io/Rdatasets/doc/reshape2/tips.html
Returns:
A `pandas.DataFrame` with 244 rows and the following columns:
`['total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']`."""
df = _get_dataset("tips")
if pretty_names:
df.rename(
mapper=dict(
total_bill="Total Bill",
tip="Tip",
sex="Payer Gender",
smoker="Smokers at Table",
day="Day of Week",
time="Meal",
size="Party Size",
),
axis="columns",
inplace=True,
)
return df
def iris():
"""
Each row represents a flower.
https://en.wikipedia.org/wiki/Iris_flower_data_set
Returns:
A `pandas.DataFrame` with 150 rows and the following columns:
`['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species', 'species_id']`."""
return _get_dataset("iris")
def wind():
"""
Each row represents a level of wind intensity in a cardinal direction, and its frequency.
Returns:
A `pandas.DataFrame` with 128 rows and the following columns:
`['direction', 'strength', 'frequency']`."""
return _get_dataset("wind")
def election():
"""
Each row represents voting results for an electoral district in the 2013 Montreal
mayoral election.
Returns:
A `pandas.DataFrame` with 58 rows and the following columns:
`['district', 'Coderre', 'Bergeron', 'Joly', 'total', 'winner', 'result', 'district_id']`."""
return _get_dataset("election")
def election_geojson():
"""
Each feature represents an electoral district in the 2013 Montreal mayoral election.
Returns:
A GeoJSON-formatted `dict` with 58 polygon or multi-polygon features whose `id`
is an electoral district numerical ID and whose `district` property is the ID and
district name."""
import gzip
import json
import os
path = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"package_data",
"datasets",
"election.geojson.gz",
)
with gzip.GzipFile(path, "r") as f:
result = json.loads(f.read().decode("utf-8"))
return result
def carshare():
"""
Each row represents the availability of car-sharing services near the centroid of a zone
in Montreal over a month-long period.
Returns:
A `pandas.DataFrame` with 249 rows and the following columns:
`['centroid_lat', 'centroid_lon', 'car_hours', 'peak_hour']`."""
return _get_dataset("carshare")
def stocks(indexed=False, datetimes=False):
"""
Each row in this wide dataset represents closing prices from 6 tech stocks in 2018/2019.
Returns:
A `pandas.DataFrame` with 100 rows and the following columns:
`['date', 'GOOG', 'AAPL', 'AMZN', 'FB', 'NFLX', 'MSFT']`.
If `indexed` is True, the 'date' column is used as the index and the column index
If `datetimes` is True, the 'date' column will be a datetime column
is named 'company'"""
df = _get_dataset("stocks")
if datetimes:
df["date"] = df["date"].astype("datetime64[ns]")
if indexed:
df = df.set_index("date")
df.columns.name = "company"
return df
def experiment(indexed=False):
"""
Each row in this wide dataset represents the results of 100 simulated participants
on three hypothetical experiments, along with their gender and control/treatment group.
Returns:
A `pandas.DataFrame` with 100 rows and the following columns:
`['experiment_1', 'experiment_2', 'experiment_3', 'gender', 'group']`.
If `indexed` is True, the data frame index is named "participant" """
df = _get_dataset("experiment")
if indexed:
df.index.name = "participant"
return df
def medals_wide(indexed=False):
"""
This dataset represents the medal table for Olympic Short Track Speed Skating for the
top three nations as of 2020.
Returns:
A `pandas.DataFrame` with 3 rows and the following columns:
`['nation', 'gold', 'silver', 'bronze']`.
If `indexed` is True, the 'nation' column is used as the index and the column index
is named 'medal'"""
df = _get_dataset("medals")
if indexed:
df = df.set_index("nation")
df.columns.name = "medal"
return df
def medals_long(indexed=False):
"""
This dataset represents the medal table for Olympic Short Track Speed Skating for the
top three nations as of 2020.
Returns:
A `pandas.DataFrame` with 9 rows and the following columns:
`['nation', 'medal', 'count']`.
If `indexed` is True, the 'nation' column is used as the index."""
df = _get_dataset("medals").melt(
id_vars=["nation"], value_name="count", var_name="medal"
)
if indexed:
df = df.set_index("nation")
return df
def _get_dataset(d):
import pandas
import os
return pandas.read_csv(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"package_data",
"datasets",
d + ".csv.gz",
)
)
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from _plotly_utils.exceptions import *
@@ -0,0 +1,112 @@
"""
`plotly.express` is a terse, consistent, high-level wrapper around `plotly.graph_objects`
for rapid data exploration and figure generation. Learn more at https://plotly.express/
"""
from __future__ import absolute_import
from plotly import optional_imports
pd = optional_imports.get_module("pandas")
if pd is None:
raise ImportError(
"""\
Plotly express requires pandas to be installed."""
)
from ._imshow import imshow
from ._chart_types import ( # noqa: F401
scatter,
scatter_3d,
scatter_polar,
scatter_ternary,
scatter_mapbox,
scatter_geo,
line,
line_3d,
line_polar,
line_ternary,
line_mapbox,
line_geo,
area,
bar,
timeline,
bar_polar,
violin,
box,
strip,
histogram,
ecdf,
scatter_matrix,
parallel_coordinates,
parallel_categories,
choropleth,
density_contour,
density_heatmap,
pie,
sunburst,
treemap,
icicle,
funnel,
funnel_area,
choropleth_mapbox,
density_mapbox,
)
from ._core import ( # noqa: F401
set_mapbox_access_token,
defaults,
get_trendline_results,
NO_COLOR,
)
from ._special_inputs import IdentityMap, Constant, Range # noqa: F401
from . import data, colors, trendline_functions # noqa: F401
__all__ = [
"scatter",
"scatter_3d",
"scatter_polar",
"scatter_ternary",
"scatter_mapbox",
"scatter_geo",
"scatter_matrix",
"density_contour",
"density_heatmap",
"density_mapbox",
"line",
"line_3d",
"line_polar",
"line_ternary",
"line_mapbox",
"line_geo",
"parallel_coordinates",
"parallel_categories",
"area",
"bar",
"timeline",
"bar_polar",
"violin",
"box",
"strip",
"histogram",
"ecdf",
"choropleth",
"choropleth_mapbox",
"pie",
"sunburst",
"treemap",
"icicle",
"funnel",
"funnel_area",
"imshow",
"data",
"colors",
"trendline_functions",
"set_mapbox_access_token",
"get_trendline_results",
"IdentityMap",
"Constant",
"Range",
"NO_COLOR",
]
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,625 @@
import inspect
from textwrap import TextWrapper
try:
getfullargspec = inspect.getfullargspec
except AttributeError: # python 2
getfullargspec = inspect.getargspec
colref_type = "str or int or Series or array-like"
colref_desc = "Either a name of a column in `data_frame`, or a pandas Series or array_like object."
colref_list_type = "list of str or int, or Series or array-like"
colref_list_desc = (
"Either names of columns in `data_frame`, or pandas Series, or array_like objects"
)
docs = dict(
data_frame=[
"DataFrame or array-like or dict",
"This argument needs to be passed for column names (and not keyword names) to be used.",
"Array-like and dict are tranformed internally to a pandas DataFrame.",
"Optional: if missing, a DataFrame gets constructed under the hood using the other arguments.",
],
x=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.",
],
y=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the y axis in cartesian coordinates.",
],
z=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the z axis in cartesian coordinates.",
],
x_start=[
colref_type,
colref_desc,
"(required)",
"Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.",
],
x_end=[
colref_type,
colref_desc,
"(required)",
"Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.",
],
a=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the a axis in ternary coordinates.",
],
b=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the b axis in ternary coordinates.",
],
c=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the c axis in ternary coordinates.",
],
r=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the radial axis in polar coordinates.",
],
theta=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the angular axis in polar coordinates.",
],
values=[
colref_type,
colref_desc,
"Values from this column or array_like are used to set values associated to sectors.",
],
parents=[
colref_type,
colref_desc,
"Values from this column or array_like are used as parents in sunburst and treemap charts.",
],
ids=[
colref_type,
colref_desc,
"Values from this column or array_like are used to set ids of sectors",
],
path=[
colref_list_type,
colref_list_desc,
"List of columns names or columns of a rectangular dataframe defining the hierarchy of sectors, from root to leaves.",
"An error is raised if path AND ids or parents is passed",
],
lat=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks according to latitude on a map.",
],
lon=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks according to longitude on a map.",
],
locations=[
colref_type,
colref_desc,
"Values from this column or array_like are to be interpreted according to `locationmode` and mapped to longitude/latitude.",
],
base=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position the base of the bar.",
],
dimensions=[
colref_list_type,
colref_list_desc,
"Values from these columns are used for multidimensional visualization.",
],
dimensions_max_cardinality=[
"int (default 50)",
"When `dimensions` is `None` and `data_frame` is provided, "
"columns with more than this number of unique values are excluded from the output.",
"Not used when `dimensions` is passed.",
],
error_x=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size x-axis error bars.",
"If `error_x_minus` is `None`, error bars will be symmetrical, otherwise `error_x` is used for the positive direction only.",
],
error_x_minus=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size x-axis error bars in the negative direction.",
"Ignored if `error_x` is `None`.",
],
error_y=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size y-axis error bars.",
"If `error_y_minus` is `None`, error bars will be symmetrical, otherwise `error_y` is used for the positive direction only.",
],
error_y_minus=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size y-axis error bars in the negative direction.",
"Ignored if `error_y` is `None`.",
],
error_z=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size z-axis error bars.",
"If `error_z_minus` is `None`, error bars will be symmetrical, otherwise `error_z` is used for the positive direction only.",
],
error_z_minus=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size z-axis error bars in the negative direction.",
"Ignored if `error_z` is `None`.",
],
color=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign color to marks.",
],
opacity=["float", "Value between 0 and 1. Sets the opacity for markers."],
line_dash=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign dash-patterns to lines.",
],
line_group=[
colref_type,
colref_desc,
"Values from this column or array_like are used to group rows of `data_frame` into lines.",
],
symbol=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign symbols to marks.",
],
pattern_shape=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign pattern shapes to marks.",
],
size=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign mark sizes.",
],
radius=["int (default is 30)", "Sets the radius of influence of each point."],
hover_name=[
colref_type,
colref_desc,
"Values from this column or array_like appear in bold in the hover tooltip.",
],
hover_data=[
"list of str or int, or Series or array-like, or dict",
"Either a list of names of columns in `data_frame`, or pandas Series,",
"or array_like objects",
"or a dict with column names as keys, with values True (for default formatting)",
"False (in order to remove this column from hover information),",
"or a formatting string, for example ':.3f' or '|%a'",
"or list-like data to appear in the hover tooltip",
"or tuples with a bool or formatting string as first element,",
"and list-like data to appear in hover as second element",
"Values from these columns appear as extra data in the hover tooltip.",
],
custom_data=[
colref_list_type,
colref_list_desc,
"Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)",
],
text=[
colref_type,
colref_desc,
"Values from this column or array_like appear in the figure as text labels.",
],
names=[
colref_type,
colref_desc,
"Values from this column or array_like are used as labels for sectors.",
],
locationmode=[
"str",
"One of 'ISO-3', 'USA-states', or 'country names'",
"Determines the set of locations used to match entries in `locations` to regions on the map.",
],
facet_row=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign marks to facetted subplots in the vertical direction.",
],
facet_col=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.",
],
facet_col_wrap=[
"int",
"Maximum number of facet columns.",
"Wraps the column variable at this width, so that the column facets span multiple rows.",
"Ignored if 0, and forced to 0 if `facet_row` or a `marginal` is set.",
],
facet_row_spacing=[
"float between 0 and 1",
"Spacing between facet rows, in paper units. Default is 0.03 or 0.0.7 when facet_col_wrap is used.",
],
facet_col_spacing=[
"float between 0 and 1",
"Spacing between facet columns, in paper units Default is 0.02.",
],
animation_frame=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign marks to animation frames.",
],
animation_group=[
colref_type,
colref_desc,
"Values from this column or array_like are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.",
],
symbol_sequence=[
"list of str",
"Strings should define valid plotly.js symbols.",
"When `symbol` is set, values in that column are assigned symbols by cycling through `symbol_sequence` in the order described in `category_orders`, unless the value of `symbol` is a key in `symbol_map`.",
],
symbol_map=[
"dict with str keys and str values (default `{}`)",
"String values should define plotly.js symbols",
"Used to override `symbol_sequence` to assign a specific symbols to marks corresponding with specific values.",
"Keys in `symbol_map` should be values in the column denoted by `symbol`.",
"Alternatively, if the values of `symbol` are valid symbol names, the string `'identity'` may be passed to cause them to be used directly.",
],
line_dash_map=[
"dict with str keys and str values (default `{}`)",
"Strings values define plotly.js dash-patterns.",
"Used to override `line_dash_sequences` to assign a specific dash-patterns to lines corresponding with specific values.",
"Keys in `line_dash_map` should be values in the column denoted by `line_dash`.",
"Alternatively, if the values of `line_dash` are valid line-dash names, the string `'identity'` may be passed to cause them to be used directly.",
],
line_dash_sequence=[
"list of str",
"Strings should define valid plotly.js dash-patterns.",
"When `line_dash` is set, values in that column are assigned dash-patterns by cycling through `line_dash_sequence` in the order described in `category_orders`, unless the value of `line_dash` is a key in `line_dash_map`.",
],
pattern_shape_map=[
"dict with str keys and str values (default `{}`)",
"Strings values define plotly.js patterns-shapes.",
"Used to override `pattern_shape_sequences` to assign a specific patterns-shapes to lines corresponding with specific values.",
"Keys in `pattern_shape_map` should be values in the column denoted by `pattern_shape`.",
"Alternatively, if the values of `pattern_shape` are valid patterns-shapes names, the string `'identity'` may be passed to cause them to be used directly.",
],
pattern_shape_sequence=[
"list of str",
"Strings should define valid plotly.js patterns-shapes.",
"When `pattern_shape` is set, values in that column are assigned patterns-shapes by cycling through `pattern_shape_sequence` in the order described in `category_orders`, unless the value of `pattern_shape` is a key in `pattern_shape_map`.",
],
color_discrete_sequence=[
"list of str",
"Strings should define valid CSS-colors.",
"When `color` is set and the values in the corresponding column are not numeric, values in that column are assigned colors by cycling through `color_discrete_sequence` in the order described in `category_orders`, unless the value of `color` is a key in `color_discrete_map`.",
"Various useful color sequences are available in the `plotly.express.colors` submodules, specifically `plotly.express.colors.qualitative`.",
],
color_discrete_map=[
"dict with str keys and str values (default `{}`)",
"String values should define valid CSS-colors",
"Used to override `color_discrete_sequence` to assign a specific colors to marks corresponding with specific values.",
"Keys in `color_discrete_map` should be values in the column denoted by `color`.",
"Alternatively, if the values of `color` are valid colors, the string `'identity'` may be passed to cause them to be used directly.",
],
color_continuous_scale=[
"list of str",
"Strings should define valid CSS-colors",
"This list is used to build a continuous color scale when the column denoted by `color` contains numeric data.",
"Various useful color scales are available in the `plotly.express.colors` submodules, specifically `plotly.express.colors.sequential`, `plotly.express.colors.diverging` and `plotly.express.colors.cyclical`.",
],
color_continuous_midpoint=[
"number (default `None`)",
"If set, computes the bounds of the continuous color scale to have the desired midpoint.",
"Setting this value is recommended when using `plotly.express.colors.diverging` color scales as the inputs to `color_continuous_scale`.",
],
size_max=["int (default `20`)", "Set the maximum mark size when using `size`."],
markers=["boolean (default `False`)", "If `True`, markers are shown on lines."],
lines=[
"boolean (default `True`)",
"If `False`, lines are not drawn (forced to `True` if `markers` is `False`).",
],
log_x=[
"boolean (default `False`)",
"If `True`, the x-axis is log-scaled in cartesian coordinates.",
],
log_y=[
"boolean (default `False`)",
"If `True`, the y-axis is log-scaled in cartesian coordinates.",
],
log_z=[
"boolean (default `False`)",
"If `True`, the z-axis is log-scaled in cartesian coordinates.",
],
log_r=[
"boolean (default `False`)",
"If `True`, the radial axis is log-scaled in polar coordinates.",
],
range_x=[
"list of two numbers",
"If provided, overrides auto-scaling on the x-axis in cartesian coordinates.",
],
range_y=[
"list of two numbers",
"If provided, overrides auto-scaling on the y-axis in cartesian coordinates.",
],
range_z=[
"list of two numbers",
"If provided, overrides auto-scaling on the z-axis in cartesian coordinates.",
],
range_color=[
"list of two numbers",
"If provided, overrides auto-scaling on the continuous color scale.",
],
range_r=[
"list of two numbers",
"If provided, overrides auto-scaling on the radial axis in polar coordinates.",
],
range_theta=[
"list of two numbers",
"If provided, overrides auto-scaling on the angular axis in polar coordinates.",
],
title=["str", "The figure title."],
template=[
"str or dict or plotly.graph_objects.layout.Template instance",
"The figure template name (must be a key in plotly.io.templates) or definition.",
],
width=["int (default `None`)", "The figure width in pixels."],
height=["int (default `None`)", "The figure height in pixels."],
labels=[
"dict with str keys and str values (default `{}`)",
"By default, column names are used in the figure for axis titles, legend entries and hovers.",
"This parameter allows this to be overridden.",
"The keys of this dict should correspond to column names, and the values should correspond to the desired label to be displayed.",
],
category_orders=[
"dict with str keys and list of str values (default `{}`)",
"By default, in Python 3.6+, the order of categorical values in axes, legends and facets depends on the order in which these values are first encountered in `data_frame` (and no order is guaranteed by default in Python below 3.6).",
"This parameter is used to force a specific ordering of values per column.",
"The keys of this dict should correspond to column names, and the values should be lists of strings corresponding to the specific display order desired.",
],
marginal=[
"str",
"One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.",
"If set, a subplot is drawn alongside the main plot, visualizing the distribution.",
],
marginal_x=[
"str",
"One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.",
"If set, a horizontal subplot is drawn above the main plot, visualizing the x-distribution.",
],
marginal_y=[
"str",
"One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.",
"If set, a vertical subplot is drawn to the right of the main plot, visualizing the y-distribution.",
],
trendline=[
"str",
"One of `'ols'`, `'lowess'`, `'rolling'`, `'expanding'` or `'ewm'`.",
"If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.",
"If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.",
"If `'rolling`', a Rolling (e.g. rolling average, rolling median) line will be drawn for each discrete-color/symbol group.",
"If `'expanding`', an Expanding (e.g. expanding average, expanding sum) line will be drawn for each discrete-color/symbol group.",
"If `'ewm`', an Exponentially Weighted Moment (e.g. exponentially-weighted moving average) line will be drawn for each discrete-color/symbol group.",
"See the docstrings for the functions in `plotly.express.trendline_functions` for more details on these functions and how",
"to configure them with the `trendline_options` argument.",
],
trendline_options=[
"dict",
"Options passed as the first argument to the function from `plotly.express.trendline_functions` ",
"named in the `trendline` argument.",
],
trendline_color_override=[
"str",
"Valid CSS color.",
"If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.",
],
trendline_scope=[
"str (one of `'trace'` or `'overall'`, default `'trace'`)",
"If `'trace'`, then one trendline is drawn per trace (i.e. per color, symbol, facet, animation frame etc) and if `'overall'` then one trendline is computed for the entire dataset, and replicated across all facets.",
],
render_mode=[
"str",
"One of `'auto'`, `'svg'` or `'webgl'`, default `'auto'`",
"Controls the browser API used to draw marks.",
"`'svg`' is appropriate for figures of less than 1000 data points, and will allow for fully-vectorized output.",
"`'webgl'` is likely necessary for acceptable performance above 1000 points but rasterizes part of the output. ",
"`'auto'` uses heuristics to choose the mode.",
],
direction=[
"str",
"One of '`counterclockwise'` or `'clockwise'`. Default is `'clockwise'`",
"Sets the direction in which increasing values of the angular axis are drawn.",
],
start_angle=[
"int (default `90`)",
"Sets start angle for the angular axis, with 0 being due east and 90 being due north.",
],
histfunc=[
"str (default `'count'` if no arguments are provided, else `'sum'`)",
"One of `'count'`, `'sum'`, `'avg'`, `'min'`, or `'max'`."
"Function used to aggregate values for summarization (note: can be normalized with `histnorm`).",
],
histnorm=[
"str (default `None`)",
"One of `'percent'`, `'probability'`, `'density'`, or `'probability density'`",
"If `None`, the output of `histfunc` is used as is.",
"If `'probability'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins.",
"If `'percent'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins and multiplied by 100.",
"If `'density'`, the output of `histfunc` for a given bin is divided by the size of the bin.",
"If `'probability density'`, the output of `histfunc` for a given bin is normalized such that it corresponds to the probability that a random event whose distribution is described by the output of `histfunc` will fall into that bin.",
],
barnorm=[
"str (default `None`)",
"One of `'fraction'` or `'percent'`.",
"If `'fraction'`, the value of each bar is divided by the sum of all values at that location coordinate.",
"`'percent'` is the same but multiplied by 100 to show percentages.",
"`None` will stack up all values at each location coordinate.",
],
groupnorm=[
"str (default `None`)",
"One of `'fraction'` or `'percent'`.",
"If `'fraction'`, the value of each point is divided by the sum of all values at that location coordinate.",
"`'percent'` is the same but multiplied by 100 to show percentages.",
"`None` will stack up all values at each location coordinate.",
],
barmode=[
"str (default `'relative'`)",
"One of `'group'`, `'overlay'` or `'relative'`",
"In `'relative'` mode, bars are stacked above zero for positive values and below zero for negative values.",
"In `'overlay'` mode, bars are drawn on top of one another.",
"In `'group'` mode, bars are placed beside each other.",
],
boxmode=[
"str (default `'group'`)",
"One of `'group'` or `'overlay'`",
"In `'overlay'` mode, boxes are on drawn top of one another.",
"In `'group'` mode, boxes are placed beside each other.",
],
violinmode=[
"str (default `'group'`)",
"One of `'group'` or `'overlay'`",
"In `'overlay'` mode, violins are on drawn top of one another.",
"In `'group'` mode, violins are placed beside each other.",
],
stripmode=[
"str (default `'group'`)",
"One of `'group'` or `'overlay'`",
"In `'overlay'` mode, strips are on drawn top of one another.",
"In `'group'` mode, strips are placed beside each other.",
],
zoom=["int (default `8`)", "Between 0 and 20.", "Sets map zoom level."],
orientation=[
"str, one of `'h'` for horizontal or `'v'` for vertical. ",
"(default `'v'` if `x` and `y` are provided and both continous or both categorical, ",
"otherwise `'v'`(`'h'`) if `x`(`y`) is categorical and `y`(`x`) is continuous, ",
"otherwise `'v'`(`'h'`) if only `x`(`y`) is provided) ",
],
line_close=[
"boolean (default `False`)",
"If `True`, an extra line segment is drawn between the first and last point.",
],
line_shape=["str (default `'linear'`)", "One of `'linear'` or `'spline'`."],
fitbounds=["str (default `False`).", "One of `False`, `locations` or `geojson`."],
basemap_visible=["bool", "Force the basemap visibility."],
scope=[
"str (default `'world'`).",
"One of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, or `'south america'`"
"Default is `'world'` unless `projection` is set to `'albers usa'`, which forces `'usa'`.",
],
projection=[
"str ",
"One of `'equirectangular'`, `'mercator'`, `'orthographic'`, `'natural earth'`, `'kavrayskiy7'`, `'miller'`, `'robinson'`, `'eckert4'`, `'azimuthal equal area'`, `'azimuthal equidistant'`, `'conic equal area'`, `'conic conformal'`, `'conic equidistant'`, `'gnomonic'`, `'stereographic'`, `'mollweide'`, `'hammer'`, `'transverse mercator'`, `'albers usa'`, `'winkel tripel'`, `'aitoff'`, or `'sinusoidal'`"
"Default depends on `scope`.",
],
center=[
"dict",
"Dict keys are `'lat'` and `'lon'`",
"Sets the center point of the map.",
],
mapbox_style=[
"str (default `'basic'`, needs Mapbox API token)",
"Identifier of base map style, some of which require a Mapbox API token to be set using `plotly.express.set_mapbox_access_token()`.",
"Allowed values which do not require a Mapbox API token are `'open-street-map'`, `'white-bg'`, `'carto-positron'`, `'carto-darkmatter'`, `'stamen-terrain'`, `'stamen-toner'`, `'stamen-watercolor'`.",
"Allowed values which do require a Mapbox API token are `'basic'`, `'streets'`, `'outdoors'`, `'light'`, `'dark'`, `'satellite'`, `'satellite-streets'`.",
],
points=[
"str or boolean (default `'outliers'`)",
"One of `'outliers'`, `'suspectedoutliers'`, `'all'`, or `False`.",
"If `'outliers'`, only the sample points lying outside the whiskers are shown.",
"If `'suspectedoutliers'`, all outlier points are shown and those less than 4*Q1-3*Q3 or greater than 4*Q3-3*Q1 are highlighted with the marker's `'outliercolor'`.",
"If `'outliers'`, only the sample points lying outside the whiskers are shown.",
"If `'all'`, all sample points are shown.",
"If `False`, no sample points are shown and the whiskers extend to the full range of the sample.",
],
box=["boolean (default `False`)", "If `True`, boxes are drawn inside the violins."],
notched=["boolean (default `False`)", "If `True`, boxes are drawn with notches."],
geojson=[
"GeoJSON-formatted dict",
"Must contain a Polygon feature collection, with IDs, which are references from `locations`.",
],
featureidkey=[
"str (default: `'id'`)",
"Path to field in GeoJSON feature object with which to match the values passed in to `locations`."
"The most common alternative to the default is of the form `'properties.<key>`.",
],
cumulative=[
"boolean (default `False`)",
"If `True`, histogram values are cumulative.",
],
nbins=["int", "Positive integer.", "Sets the number of bins."],
nbinsx=["int", "Positive integer.", "Sets the number of bins along the x axis."],
nbinsy=["int", "Positive integer.", "Sets the number of bins along the y axis."],
branchvalues=[
"str",
"'total' or 'remainder'",
"Determines how the items in `values` are summed. When"
"set to 'total', items in `values` are taken to be value"
"of all its descendants. When set to 'remainder', items"
"in `values` corresponding to the root and the branches"
":sectors are taken to be the extra part not part of the"
"sum of the values at their leaves.",
],
maxdepth=[
"int",
"Positive integer",
"Sets the number of rendered sectors from any given `level`. Set `maxdepth` to -1 to render all the"
"levels in the hierarchy.",
],
ecdfnorm=[
"string or `None` (default `'probability'`)",
"One of `'probability'` or `'percent'`",
"If `None`, values will be raw counts or sums.",
"If `'probability', values will be probabilities normalized from 0 to 1.",
"If `'percent', values will be percentages normalized from 0 to 100.",
],
ecdfmode=[
"string (default `'standard'`)",
"One of `'standard'`, `'complementary'` or `'reversed'`",
"If `'standard'`, the ECDF is plotted such that values represent data at or below the point.",
"If `'complementary'`, the CCDF is plotted such that values represent data above the point.",
"If `'reversed'`, a variant of the CCDF is plotted such that values represent data at or above the point.",
],
text_auto=[
"bool or string (default `False`)",
"If `True` or a string, the x or y or z values will be displayed as text, depending on the orientation",
"A string like `'.2f'` will be interpreted as a `texttemplate` numeric formatting directive.",
],
)
def make_docstring(fn, override_dict=None, append_dict=None):
override_dict = {} if override_dict is None else override_dict
append_dict = {} if append_dict is None else append_dict
tw = TextWrapper(width=75, initial_indent=" ", subsequent_indent=" ")
result = (fn.__doc__ or "") + "\nParameters\n----------\n"
for param in getfullargspec(fn)[0]:
if override_dict.get(param):
param_doc = list(override_dict[param])
else:
param_doc = list(docs[param])
if append_dict.get(param):
param_doc += append_dict[param]
param_desc_list = param_doc[1:]
param_desc = (
tw.fill(" ".join(param_desc_list or ""))
if param in docs or param in override_dict
else "(documentation missing from map)"
)
param_type = param_doc[0]
result += "%s: %s\n%s\n" % (param, param_type, param_desc)
result += "\nReturns\n-------\n"
result += " plotly.graph_objects.Figure"
return result
@@ -0,0 +1,600 @@
import plotly.graph_objs as go
from _plotly_utils.basevalidators import ColorscaleValidator
from ._core import apply_default_cascade, init_figure, configure_animation_controls
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
import pandas as pd
import numpy as np
import itertools
from plotly.utils import image_array_to_data_uri
try:
import xarray
xarray_imported = True
except ImportError:
xarray_imported = False
_float_types = []
def _vectorize_zvalue(z, mode="max"):
alpha = 255 if mode == "max" else 0
if z is None:
return z
elif np.isscalar(z):
return [z] * 3 + [alpha]
elif len(z) == 1:
return list(z) * 3 + [alpha]
elif len(z) == 3:
return list(z) + [alpha]
elif len(z) == 4:
return z
else:
raise ValueError(
"zmax can be a scalar, or an iterable of length 1, 3 or 4. "
"A value of %s was passed for zmax." % str(z)
)
def _infer_zmax_from_type(img):
dt = img.dtype.type
rtol = 1.05
if dt in _integer_types:
return _integer_ranges[dt][1]
else:
im_max = img[np.isfinite(img)].max()
if im_max <= 1 * rtol:
return 1
elif im_max <= 255 * rtol:
return 255
elif im_max <= 65535 * rtol:
return 65535
else:
return 2**32
def imshow(
img,
zmin=None,
zmax=None,
origin=None,
labels={},
x=None,
y=None,
animation_frame=None,
facet_col=None,
facet_col_wrap=None,
facet_col_spacing=None,
facet_row_spacing=None,
color_continuous_scale=None,
color_continuous_midpoint=None,
range_color=None,
title=None,
template=None,
width=None,
height=None,
aspect=None,
contrast_rescaling=None,
binary_string=None,
binary_backend="auto",
binary_compression_level=4,
binary_format="png",
text_auto=False,
) -> go.Figure:
"""
Display an image, i.e. data on a 2D regular raster.
Parameters
----------
img: array-like image, or xarray
The image data. Supported array shapes are
- (M, N): an image with scalar data. The data is visualized
using a colormap.
- (M, N, 3): an image with RGB values.
- (M, N, 4): an image with RGBA values, i.e. including transparency.
zmin, zmax : scalar or iterable, optional
zmin and zmax define the scalar range that the colormap covers. By default,
zmin and zmax correspond to the min and max values of the datatype for integer
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
a multichannel image of floats, the max of the image is computed and zmax is the
smallest power of 256 (1, 255, 65535) greater than this max value,
with a 5% tolerance. For a single-channel image, the max of the image is used.
Overridden by range_color.
origin : str, 'upper' or 'lower' (default 'upper')
position of the [0, 0] pixel of the image array, in the upper left or lower left
corner. The convention 'upper' is typically used for matrices and images.
labels : dict with str keys and str values (default `{}`)
Sets names used in the figure for axis titles (keys ``x`` and ``y``),
colorbar title and hoverlabel (key ``color``). The values should correspond
to the desired label to be displayed. If ``img`` is an xarray, dimension
names are used for axis titles, and long name for the colorbar title
(unless overridden in ``labels``). Possible keys are: x, y, and color.
x, y: list-like, optional
x and y are used to label the axes of single-channel heatmap visualizations and
their lengths must match the lengths of the second and first dimensions of the
img argument. They are auto-populated if the input is an xarray.
animation_frame: int or str, optional (default None)
axis number along which the image array is sliced to create an animation plot.
If `img` is an xarray, `animation_frame` can be the name of one the dimensions.
facet_col: int or str, optional (default None)
axis number along which the image array is sliced to create a facetted plot.
If `img` is an xarray, `facet_col` can be the name of one the dimensions.
facet_col_wrap: int
Maximum number of facet columns. Wraps the column variable at this width,
so that the column facets span multiple rows.
Ignored if `facet_col` is None.
facet_col_spacing: float between 0 and 1
Spacing between facet columns, in paper units. Default is 0.02.
facet_row_spacing: float between 0 and 1
Spacing between facet rows created when ``facet_col_wrap`` is used, in
paper units. Default is 0.0.7.
color_continuous_scale : str or list of str
colormap used to map scalar data to colors (for a 2D image). This parameter is
not used for RGB or RGBA images. If a string is provided, it should be the name
of a known color scale, and if a list is provided, it should be a list of CSS-
compatible colors.
color_continuous_midpoint : number
If set, computes the bounds of the continuous color scale to have the desired
midpoint. Overridden by range_color or zmin and zmax.
range_color : list of two numbers
If provided, overrides auto-scaling on the continuous color scale, including
overriding `color_continuous_midpoint`. Also overrides zmin and zmax. Used only
for single-channel images.
title : str
The figure title.
template : str or dict or plotly.graph_objects.layout.Template instance
The figure template name or definition.
width : number
The figure width in pixels.
height: number
The figure height in pixels.
aspect: 'equal', 'auto', or None
- 'equal': Ensures an aspect ratio of 1 or pixels (square pixels)
- 'auto': The axes is kept fixed and the aspect ratio of pixels is
adjusted so that the data fit in the axes. In general, this will
result in non-square pixels.
- if None, 'equal' is used for numpy arrays and 'auto' for xarrays
(which have typically heterogeneous coordinates)
contrast_rescaling: 'minmax', 'infer', or None
how to determine data values corresponding to the bounds of the color
range, when zmin or zmax are not passed. If `minmax`, the min and max
values of the image are used. If `infer`, a heuristic based on the image
data type is used.
binary_string: bool, default None
if True, the image data are first rescaled and encoded as uint8 and
then passed to plotly.js as a b64 PNG string. If False, data are passed
unchanged as a numerical array. Setting to True may lead to performance
gains, at the cost of a loss of precision depending on the original data
type. If None, use_binary_string is set to True for multichannel (eg) RGB
arrays, and to False for single-channel (2D) arrays. 2D arrays are
represented as grayscale and with no colorbar if use_binary_string is
True.
binary_backend: str, 'auto' (default), 'pil' or 'pypng'
Third-party package for the transformation of numpy arrays to
png b64 strings. If 'auto', Pillow is used if installed, otherwise
pypng.
binary_compression_level: int, between 0 and 9 (default 4)
png compression level to be passed to the backend when transforming an
array to a png b64 string. Increasing `binary_compression` decreases the
size of the png string, but the compression step takes more time. For most
images it is not worth using levels greater than 5, but it's possible to
test `len(fig.data[0].source)` and to time the execution of `imshow` to
tune the level of compression. 0 means no compression (not recommended).
binary_format: str, 'png' (default) or 'jpg'
compression format used to generate b64 string. 'png' is recommended
since it uses lossless compression, but 'jpg' (lossy) compression can
result if smaller binary strings for natural images.
text_auto: bool or str (default `False`)
If `True` or a string, single-channel `img` values will be displayed as text.
A string like `'.2f'` will be interpreted as a `texttemplate` numeric formatting directive.
Returns
-------
fig : graph_objects.Figure containing the displayed image
See also
--------
plotly.graph_objects.Image : image trace
plotly.graph_objects.Heatmap : heatmap trace
Notes
-----
In order to update and customize the returned figure, use
`go.Figure.update_traces` or `go.Figure.update_layout`.
If an xarray is passed, dimensions names and coordinates are used for
axes labels and ticks.
"""
args = locals()
apply_default_cascade(args)
labels = labels.copy()
nslices_facet = 1
if facet_col is not None:
if isinstance(facet_col, str):
facet_col = img.dims.index(facet_col)
nslices_facet = img.shape[facet_col]
facet_slices = range(nslices_facet)
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
nrows = (
nslices_facet // ncols + 1
if nslices_facet % ncols
else nslices_facet // ncols
)
else:
nrows = 1
ncols = 1
if animation_frame is not None:
if isinstance(animation_frame, str):
animation_frame = img.dims.index(animation_frame)
nslices_animation = img.shape[animation_frame]
animation_slices = range(nslices_animation)
slice_dimensions = (facet_col is not None) + (
animation_frame is not None
) # 0, 1, or 2
facet_label = None
animation_label = None
img_is_xarray = False
# ----- Define x and y, set labels if img is an xarray -------------------
if xarray_imported and isinstance(img, xarray.DataArray):
dims = list(img.dims)
img_is_xarray = True
if facet_col is not None:
facet_slices = img.coords[img.dims[facet_col]].values
_ = dims.pop(facet_col)
facet_label = img.dims[facet_col]
if animation_frame is not None:
animation_slices = img.coords[img.dims[animation_frame]].values
_ = dims.pop(animation_frame)
animation_label = img.dims[animation_frame]
y_label, x_label = dims[0], dims[1]
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
img.coords[ax] = img.coords[ax].astype(str)
if x is None:
x = img.coords[x_label].values
if y is None:
y = img.coords[y_label].values
if aspect is None:
aspect = "auto"
if labels.get("x", None) is None:
labels["x"] = x_label
if labels.get("y", None) is None:
labels["y"] = y_label
if labels.get("animation_frame", None) is None:
labels["animation_frame"] = animation_label
if labels.get("facet_col", None) is None:
labels["facet_col"] = facet_label
if labels.get("color", None) is None:
labels["color"] = xarray.plot.utils.label_from_attrs(img)
labels["color"] = labels["color"].replace("\n", "<br>")
else:
if hasattr(img, "columns") and hasattr(img.columns, "__len__"):
if x is None:
x = img.columns
if labels.get("x", None) is None and hasattr(img.columns, "name"):
labels["x"] = img.columns.name or ""
if hasattr(img, "index") and hasattr(img.index, "__len__"):
if y is None:
y = img.index
if labels.get("y", None) is None and hasattr(img.index, "name"):
labels["y"] = img.index.name or ""
if labels.get("x", None) is None:
labels["x"] = ""
if labels.get("y", None) is None:
labels["y"] = ""
if labels.get("color", None) is None:
labels["color"] = ""
if aspect is None:
aspect = "equal"
# --- Set the value of binary_string (forbidden for pandas)
if isinstance(img, pd.DataFrame):
if binary_string:
raise ValueError("Binary strings cannot be used with pandas arrays")
is_dataframe = True
else:
is_dataframe = False
# --------------- Starting from here img is always a numpy array --------
img = np.asanyarray(img)
# Reshape array so that animation dimension comes first, then facets, then images
if facet_col is not None:
img = np.moveaxis(img, facet_col, 0)
if animation_frame is not None and animation_frame < facet_col:
animation_frame += 1
facet_col = True
if animation_frame is not None:
img = np.moveaxis(img, animation_frame, 0)
animation_frame = True
args["animation_frame"] = (
"animation_frame"
if labels.get("animation_frame") is None
else labels["animation_frame"]
)
iterables = ()
if animation_frame is not None:
iterables += (range(nslices_animation),)
if facet_col is not None:
iterables += (range(nslices_facet),)
# Default behaviour of binary_string: True for RGB images, False for 2D
if binary_string is None:
binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe
# Cast bools to uint8 (also one byte)
if img.dtype == bool:
img = 255 * img.astype(np.uint8)
if range_color is not None:
zmin = range_color[0]
zmax = range_color[1]
# -------- Contrast rescaling: either minmax or infer ------------------
if contrast_rescaling is None:
contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer"
# We try to set zmin and zmax only if necessary, because traces have good defaults
if contrast_rescaling == "minmax":
# When using binary_string and minmax we need to set zmin and zmax to rescale the image
if (zmin is not None or binary_string) and zmax is None:
zmax = img.max()
if (zmax is not None or binary_string) and zmin is None:
zmin = img.min()
else:
# For uint8 data and infer we let zmin and zmax to be None if passed as None
if zmax is None and img.dtype != np.uint8:
zmax = _infer_zmax_from_type(img)
if zmin is None and zmax is not None:
zmin = 0
# For 2d data, use Heatmap trace, unless binary_string is True
if img.ndim == 2 + slice_dimensions and not binary_string:
y_index = slice_dimensions
if y is not None and img.shape[y_index] != len(y):
raise ValueError(
"The length of the y vector must match the length of the first "
+ "dimension of the img matrix."
)
x_index = slice_dimensions + 1
if x is not None and img.shape[x_index] != len(x):
raise ValueError(
"The length of the x vector must match the length of the second "
+ "dimension of the img matrix."
)
texttemplate = None
if text_auto is True:
texttemplate = "%{z}"
elif text_auto is not False:
texttemplate = "%{z:" + text_auto + "}"
traces = [
go.Heatmap(
x=x,
y=y,
z=img[index_tup],
coloraxis="coloraxis1",
name=str(i),
texttemplate=texttemplate,
)
for i, index_tup in enumerate(itertools.product(*iterables))
]
autorange = True if origin == "lower" else "reversed"
layout = dict(yaxis=dict(autorange=autorange))
if aspect == "equal":
layout["xaxis"] = dict(scaleanchor="y", constrain="domain")
layout["yaxis"]["constrain"] = "domain"
colorscale_validator = ColorscaleValidator("colorscale", "imshow")
layout["coloraxis1"] = dict(
colorscale=colorscale_validator.validate_coerce(
args["color_continuous_scale"]
),
cmid=color_continuous_midpoint,
cmin=zmin,
cmax=zmax,
)
if labels["color"]:
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])
# For 2D+RGB data, use Image trace
elif (
img.ndim >= 3
and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string)
) or (img.ndim == 2 and binary_string):
rescale_image = True # to check whether image has been modified
if zmin is not None and zmax is not None:
zmin, zmax = (
_vectorize_zvalue(zmin, mode="min"),
_vectorize_zvalue(zmax, mode="max"),
)
x0, y0, dx, dy = (None,) * 4
error_msg_xarray = (
"Non-numerical coordinates were passed with xarray `img`, but "
"the Image trace cannot handle it. Please use `binary_string=False` "
"for 2D data or pass instead the numpy array `img.values` to `px.imshow`."
)
if x is not None:
x = np.asanyarray(x)
if np.issubdtype(x.dtype, np.number):
x0 = x[0]
dx = x[1] - x[0]
else:
error_msg = (
error_msg_xarray
if img_is_xarray
else (
"Only numerical values are accepted for the `x` parameter "
"when an Image trace is used."
)
)
raise ValueError(error_msg)
if y is not None:
y = np.asanyarray(y)
if np.issubdtype(y.dtype, np.number):
y0 = y[0]
dy = y[1] - y[0]
else:
error_msg = (
error_msg_xarray
if img_is_xarray
else (
"Only numerical values are accepted for the `y` parameter "
"when an Image trace is used."
)
)
raise ValueError(error_msg)
if binary_string:
if zmin is None and zmax is None: # no rescaling, faster
img_rescaled = img
rescale_image = False
elif img.ndim == 2 + slice_dimensions: # single-channel image
img_rescaled = rescale_intensity(
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
)
else:
img_rescaled = np.stack(
[
rescale_intensity(
img[..., ch],
in_range=(zmin[ch], zmax[ch]),
out_range=np.uint8,
)
for ch in range(img.shape[-1])
],
axis=-1,
)
img_str = [
image_array_to_data_uri(
img_rescaled[index_tup],
backend=binary_backend,
compression=binary_compression_level,
ext=binary_format,
)
for index_tup in itertools.product(*iterables)
]
traces = [
go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy)
for i, img_str_slice in enumerate(img_str)
]
else:
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
traces = [
go.Image(
z=img[index_tup],
zmin=zmin,
zmax=zmax,
colormodel=colormodel,
x0=x0,
y0=y0,
dx=dx,
dy=dy,
)
for index_tup in itertools.product(*iterables)
]
layout = {}
if origin == "lower" or (dy is not None and dy < 0):
layout["yaxis"] = dict(autorange=True)
if dx is not None and dx < 0:
layout["xaxis"] = dict(autorange="reversed")
else:
raise ValueError(
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
"An image of shape %s was provided. "
"Alternatively, 3- or 4-D single or multichannel datasets can be "
"visualized using the `facet_col` or/and `animation_frame` arguments."
% str(img.shape)
)
# Now build figure
col_labels = []
if facet_col is not None:
slice_label = (
"facet_col" if labels.get("facet_col") is None else labels["facet_col"]
)
col_labels = ["%s=%d" % (slice_label, i) for i in facet_slices]
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
for attr_name in ["height", "width"]:
if args[attr_name]:
layout[attr_name] = args[attr_name]
if args["title"]:
layout["title_text"] = args["title"]
elif args["template"].layout.margin.t is None:
layout["margin"] = {"t": 60}
frame_list = []
for index, trace in enumerate(traces):
if (facet_col and index < nrows * ncols) or index == 0:
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
if animation_frame is not None:
for i, index in zip(range(nslices_animation), animation_slices):
frame_list.append(
dict(
data=traces[nslices_facet * i : nslices_facet * (i + 1)],
layout=layout,
name=str(index),
)
)
if animation_frame:
fig.frames = frame_list
fig.update_layout(layout)
# Hover name, z or color
if binary_string and rescale_image and not np.all(img == img_rescaled):
# we rescaled the image, hence z is not displayed in hover since it does
# not correspond to img values
hovertemplate = "%s: %%{x}<br>%s: %%{y}<extra></extra>" % (
labels["x"] or "x",
labels["y"] or "y",
)
else:
if trace["type"] == "heatmap":
hover_name = "%{z}"
elif img.ndim == 2:
hover_name = "%{z[0]}"
elif img.ndim == 3 and img.shape[-1] == 3:
hover_name = "[%{z[0]}, %{z[1]}, %{z[2]}]"
else:
hover_name = "%{z}"
hovertemplate = "%s: %%{x}<br>%s: %%{y}<br>%s: %s<extra></extra>" % (
labels["x"] or "x",
labels["y"] or "y",
labels["color"] or "color",
hover_name,
)
fig.update_traces(hovertemplate=hovertemplate)
if labels["x"]:
fig.update_xaxes(title_text=labels["x"], row=1)
if labels["y"]:
fig.update_yaxes(title_text=labels["y"], col=1)
configure_animation_controls(args, go.Image, fig)
fig.update_layout(template=args["template"], overwrite=True)
return fig
@@ -0,0 +1,40 @@
class IdentityMap(object):
"""
`dict`-like object which acts as if the value for any key is the key itself. Objects
of this class can be passed in to arguments like `color_discrete_map` to
use the provided data values as colors, rather than mapping them to colors cycled
from `color_discrete_sequence`. This works for any `_map` argument to Plotly Express
functions, such as `line_dash_map` and `symbol_map`.
"""
def __getitem__(self, key):
return key
def __contains__(self, key):
return True
def copy(self):
return self
class Constant(object):
"""
Objects of this class can be passed to Plotly Express functions that expect column
identifiers or list-like objects to indicate that this attribute should take on a
constant value. An optional label can be provided.
"""
def __init__(self, value, label=None):
self.value = value
self.label = label
class Range(object):
"""
Objects of this class can be passed to Plotly Express functions that expect column
identifiers or list-like objects to indicate that this attribute should be mapped
onto integers starting at 0. An optional label can be provided.
"""
def __init__(self, label=None):
self.label = label
@@ -0,0 +1,52 @@
"""For a list of colors available in `plotly.express.colors`, please see
* the `tutorial on discrete color sequences <https://plotly.com/python/discrete-color/#color-sequences-in-plotly-express>`_
* the `list of built-in continuous color scales <https://plotly.com/python/builtin-colorscales/>`_
* the `tutorial on continuous colors <https://plotly.com/python/colorscales/>`_
Color scales are available within the following namespaces
* cyclical
* diverging
* qualitative
* sequential
"""
from __future__ import absolute_import
from plotly.colors import *
__all__ = [
"named_colorscales",
"cyclical",
"diverging",
"sequential",
"qualitative",
"colorbrewer",
"colorbrewer",
"carto",
"cmocean",
"color_parser",
"colorscale_to_colors",
"colorscale_to_scale",
"convert_colors_to_same_type",
"convert_colorscale_to_rgb",
"convert_dict_colors_to_same_type",
"convert_to_RGB_255",
"find_intermediate_color",
"hex_to_rgb",
"label_rgb",
"make_colorscale",
"n_colors",
"unconvert_from_RGB_255",
"unlabel_rgb",
"validate_colors",
"validate_colors_dict",
"validate_colorscale",
"validate_scale_values",
"plotlyjs",
"DEFAULT_PLOTLY_COLORS",
"PLOTLY_SCALES",
"get_colorscale",
"sample_colorscale",
]
@@ -0,0 +1,19 @@
"""Built-in datasets for demonstration, educational and test purposes.
"""
from __future__ import absolute_import
from plotly.data import *
__all__ = [
"carshare",
"election",
"election_geojson",
"experiment",
"gapminder",
"iris",
"medals_wide",
"medals_long",
"stocks",
"tips",
"wind",
]
@@ -0,0 +1,248 @@
"""Vendored code from scikit-image in order to limit the number of dependencies
Extracted from scikit-image/skimage/exposure/exposure.py
"""
import numpy as np
from warnings import warn
_integer_types = (
np.byte,
np.ubyte, # 8 bits
np.short,
np.ushort, # 16 bits
np.intc,
np.uintc, # 16 or 32 or 64 bits
np.int_,
np.uint, # 32 or 64 bits
np.longlong,
np.ulonglong,
) # 64 bits
_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types}
dtype_range = {
np.bool_: (False, True),
np.bool8: (False, True),
np.float16: (-1, 1),
np.float32: (-1, 1),
np.float64: (-1, 1),
}
dtype_range.update(_integer_ranges)
DTYPE_RANGE = dtype_range.copy()
DTYPE_RANGE.update((d.__name__, limits) for d, limits in dtype_range.items())
DTYPE_RANGE.update(
{
"uint10": (0, 2**10 - 1),
"uint12": (0, 2**12 - 1),
"uint14": (0, 2**14 - 1),
"bool": dtype_range[np.bool_],
"float": dtype_range[np.float64],
}
)
def intensity_range(image, range_values="image", clip_negative=False):
"""Return image intensity range (min, max) based on desired value type.
Parameters
----------
image : array
Input image.
range_values : str or 2-tuple, optional
The image intensity range is configured by this parameter.
The possible values for this parameter are enumerated below.
'image'
Return image min/max as the range.
'dtype'
Return min/max of the image's dtype as the range.
dtype-name
Return intensity range based on desired `dtype`. Must be valid key
in `DTYPE_RANGE`. Note: `image` is ignored for this range type.
2-tuple
Return `range_values` as min/max intensities. Note that there's no
reason to use this function if you just want to specify the
intensity range explicitly. This option is included for functions
that use `intensity_range` to support all desired range types.
clip_negative : bool, optional
If True, clip the negative range (i.e. return 0 for min intensity)
even if the image dtype allows negative values.
"""
if range_values == "dtype":
range_values = image.dtype.type
if range_values == "image":
i_min = np.min(image)
i_max = np.max(image)
elif range_values in DTYPE_RANGE:
i_min, i_max = DTYPE_RANGE[range_values]
if clip_negative:
i_min = 0
else:
i_min, i_max = range_values
return i_min, i_max
def _output_dtype(dtype_or_range):
"""Determine the output dtype for rescale_intensity.
The dtype is determined according to the following rules:
- if ``dtype_or_range`` is a dtype, that is the output dtype.
- if ``dtype_or_range`` is a dtype string, that is the dtype used, unless
it is not a NumPy data type (e.g. 'uint12' for 12-bit unsigned integers),
in which case the data type that can contain it will be used
(e.g. uint16 in this case).
- if ``dtype_or_range`` is a pair of values, the output data type will be
float.
Parameters
----------
dtype_or_range : type, string, or 2-tuple of int/float
The desired range for the output, expressed as either a NumPy dtype or
as a (min, max) pair of numbers.
Returns
-------
out_dtype : type
The data type appropriate for the desired output.
"""
if type(dtype_or_range) in [list, tuple, np.ndarray]:
# pair of values: always return float.
return np.float_
if type(dtype_or_range) == type:
# already a type: return it
return dtype_or_range
if dtype_or_range in DTYPE_RANGE:
# string key in DTYPE_RANGE dictionary
try:
# if it's a canonical numpy dtype, convert
return np.dtype(dtype_or_range).type
except TypeError: # uint10, uint12, uint14
# otherwise, return uint16
return np.uint16
else:
raise ValueError(
"Incorrect value for out_range, should be a valid image data "
"type or a pair of values, got %s." % str(dtype_or_range)
)
def rescale_intensity(image, in_range="image", out_range="dtype"):
"""Return image after stretching or shrinking its intensity levels.
The desired intensity range of the input and output, `in_range` and
`out_range` respectively, are used to stretch or shrink the intensity range
of the input image. See examples below.
Parameters
----------
image : array
Image array.
in_range, out_range : str or 2-tuple, optional
Min and max intensity values of input and output image.
The possible values for this parameter are enumerated below.
'image'
Use image min/max as the intensity range.
'dtype'
Use min/max of the image's dtype as the intensity range.
dtype-name
Use intensity range based on desired `dtype`. Must be valid key
in `DTYPE_RANGE`.
2-tuple
Use `range_values` as explicit min/max intensities.
Returns
-------
out : array
Image array after rescaling its intensity. This image is the same dtype
as the input image.
Notes
-----
.. versionchanged:: 0.17
The dtype of the output array has changed to match the output dtype, or
float if the output range is specified by a pair of floats.
See Also
--------
equalize_hist
Examples
--------
By default, the min/max intensities of the input image are stretched to
the limits allowed by the image's dtype, since `in_range` defaults to
'image' and `out_range` defaults to 'dtype':
>>> image = np.array([51, 102, 153], dtype=np.uint8)
>>> rescale_intensity(image)
array([ 0, 127, 255], dtype=uint8)
It's easy to accidentally convert an image dtype from uint8 to float:
>>> 1.0 * image
array([ 51., 102., 153.])
Use `rescale_intensity` to rescale to the proper range for float dtypes:
>>> image_float = 1.0 * image
>>> rescale_intensity(image_float)
array([0. , 0.5, 1. ])
To maintain the low contrast of the original, use the `in_range` parameter:
>>> rescale_intensity(image_float, in_range=(0, 255))
array([0.2, 0.4, 0.6])
If the min/max value of `in_range` is more/less than the min/max image
intensity, then the intensity levels are clipped:
>>> rescale_intensity(image_float, in_range=(0, 102))
array([0.5, 1. , 1. ])
If you have an image with signed integers but want to rescale the image to
just the positive range, use the `out_range` parameter. In that case, the
output dtype will be float:
>>> image = np.array([-10, 0, 10], dtype=np.int8)
>>> rescale_intensity(image, out_range=(0, 127))
array([ 0. , 63.5, 127. ])
To get the desired range with a specific dtype, use ``.astype()``:
>>> rescale_intensity(image, out_range=(0, 127)).astype(np.int8)
array([ 0, 63, 127], dtype=int8)
If the input image is constant, the output will be clipped directly to the
output range:
>>> image = np.array([130, 130, 130], dtype=np.int32)
>>> rescale_intensity(image, out_range=(0, 127)).astype(np.int32)
array([127, 127, 127], dtype=int32)
"""
if out_range in ["dtype", "image"]:
out_dtype = _output_dtype(image.dtype.type)
else:
out_dtype = _output_dtype(out_range)
imin, imax = map(float, intensity_range(image, in_range))
omin, omax = map(
float, intensity_range(image, out_range, clip_negative=(imin >= 0))
)
if np.any(np.isnan([imin, imax, omin, omax])):
warn(
"One or more intensity levels are NaN. Rescaling will broadcast "
"NaN to the full image. Provide intensity levels yourself to "
"avoid this. E.g. with np.nanmin(image), np.nanmax(image).",
stacklevel=2,
)
image = np.clip(image, imin, imax)
if imin != imax:
image = (image - imin) / (imax - imin)
return np.asarray(image * (omax - omin) + omin, dtype=out_dtype)
else:
return np.clip(image, omin, omax).astype(out_dtype)
@@ -0,0 +1,157 @@
"""
The `trendline_functions` module contains functions which are called by Plotly Express
when the `trendline` argument is used. Valid values for `trendline` are the names of the
functions in this module, and the value of the `trendline_options` argument to PX
functions is passed in as the first argument to these functions when called.
Note that the functions in this module are not meant to be called directly, and are
exposed as part of the public API for documentation purposes.
"""
import pandas as pd
import numpy as np
__all__ = ["ols", "lowess", "rolling", "ewm", "expanding"]
def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Ordinary Least Squares (OLS) trendline function
Requires `statsmodels` to be installed.
This trendline function causes fit results to be stored within the figure,
accessible via the `plotly.express.get_trendline_results` function. The fit results
are the output of the `statsmodels.api.OLS` function.
Valid keys for the `trendline_options` dict are:
- `add_constant` (`bool`, default `True`): if `False`, the trendline passes through
the origin but if `True` a y-intercept is fitted.
- `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with
respect to the base 10 logarithm of the input. Note that this means no zeros can
be present in the input.
"""
valid_options = ["add_constant", "log_x", "log_y"]
for k in trendline_options.keys():
if k not in valid_options:
raise ValueError(
"OLS trendline_options keys must be one of [%s] but got '%s'"
% (", ".join(valid_options), k)
)
import statsmodels.api as sm
add_constant = trendline_options.get("add_constant", True)
log_x = trendline_options.get("log_x", False)
log_y = trendline_options.get("log_y", False)
if log_y:
if np.any(y <= 0):
raise ValueError(
"Can't do OLS trendline with `log_y=True` when `y` contains non-positive values."
)
y = np.log10(y)
y_label = "log10(%s)" % y_label
if log_x:
if np.any(x <= 0):
raise ValueError(
"Can't do OLS trendline with `log_x=True` when `x` contains non-positive values."
)
x = np.log10(x)
x_label = "log10(%s)" % x_label
if add_constant:
x = sm.add_constant(x)
fit_results = sm.OLS(y, x, missing="drop").fit()
y_out = fit_results.predict()
if log_y:
y_out = np.power(10, y_out)
hover_header = "<b>OLS trendline</b><br>"
if len(fit_results.params) == 2:
hover_header += "%s = %g * %s + %g<br>" % (
y_label,
fit_results.params[1],
x_label,
fit_results.params[0],
)
elif not add_constant:
hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label)
else:
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0])
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
return y_out, hover_header, fit_results
def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function
Requires `statsmodels` to be installed.
Valid keys for the `trendline_options` dict are:
- `frac` (`float`, default `0.6666666`): the `frac` parameter from the
`statsmodels.api.nonparametric.lowess` function
"""
valid_options = ["frac"]
for k in trendline_options.keys():
if k not in valid_options:
raise ValueError(
"LOWESS trendline_options keys must be one of [%s] but got '%s'"
% (", ".join(valid_options), k)
)
import statsmodels.api as sm
frac = trendline_options.get("frac", 0.6666666)
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
hover_header = "<b>LOWESS trendline</b><br><br>"
return y_out, hover_header, None
def _pandas(mode, trendline_options, x_raw, y, non_missing):
modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
trendline_options = trendline_options.copy()
function_name = trendline_options.pop("function", "mean")
function_args = trendline_options.pop("function_args", dict())
series = pd.Series(y, index=x_raw)
agg = getattr(series, mode) # e.g. series.rolling
agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts)
function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean
y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts)
y_out = y_out[non_missing]
hover_header = "<b>%s %s trendline</b><br><br>" % (modes[mode], function_name)
return y_out, hover_header, None
def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Rolling trendline function
The value of the `function` key of the `trendline_options` dict is the function to
use (defaults to `mean`) and the value of the `function_args` key are taken to be
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
keyword arguments into the `pandas.Series.rolling` function.
"""
return _pandas("rolling", trendline_options, x_raw, y, non_missing)
def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Expanding trendline function
The value of the `function` key of the `trendline_options` dict is the function to
use (defaults to `mean`) and the value of the `function_args` key are taken to be
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
keyword arguments into the `pandas.Series.expanding` function.
"""
return _pandas("expanding", trendline_options, x_raw, y, non_missing)
def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Exponentially Weighted Moment (EWM) trendline function
The value of the `function` key of the `trendline_options` dict is the function to
use (defaults to `mean`) and the value of the `function_args` key are taken to be
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
keyword arguments into the `pandas.Series.ewm` function.
"""
return _pandas("ewm", trendline_options, x_raw, y, non_missing)
@@ -0,0 +1,157 @@
from __future__ import absolute_import
from numbers import Number
import plotly.exceptions
import plotly.colors as clrs
from plotly.graph_objs import graph_objs
def make_linear_colorscale(colors):
"""
Makes a list of colors into a colorscale-acceptable form
For documentation regarding to the form of the output, see
https://plot.ly/python/reference/#mesh3d-colorscale
"""
scale = 1.0 / (len(colors) - 1)
return [[i * scale, color] for i, color in enumerate(colors)]
def create_2d_density(
x,
y,
colorscale="Earth",
ncontours=20,
hist_color=(0, 0, 0.5),
point_color=(0, 0, 0.5),
point_size=2,
title="2D Density Plot",
height=600,
width=600,
):
"""
**deprecated**, use instead
:func:`plotly.express.density_heatmap`.
:param (list|array) x: x-axis data for plot generation
:param (list|array) y: y-axis data for plot generation
:param (str|tuple|list) colorscale: either a plotly scale name, an rgb
or hex color, a color tuple or a list or tuple of colors. An rgb
color is of the form 'rgb(x, y, z)' where x, y, z belong to the
interval [0, 255] and a color tuple is a tuple of the form
(a, b, c) where a, b and c belong to [0, 1]. If colormap is a
list, it must contain the valid color types aforementioned as its
members.
:param (int) ncontours: the number of 2D contours to draw on the plot
:param (str) hist_color: the color of the plotted histograms
:param (str) point_color: the color of the scatter points
:param (str) point_size: the color of the scatter points
:param (str) title: set the title for the plot
:param (float) height: the height of the chart
:param (float) width: the width of the chart
Examples
--------
Example 1: Simple 2D Density Plot
>>> from plotly.figure_factory import create_2d_density
>>> import numpy as np
>>> # Make data points
>>> t = np.linspace(-1,1.2,2000)
>>> x = (t**3)+(0.3*np.random.randn(2000))
>>> y = (t**6)+(0.3*np.random.randn(2000))
>>> # Create a figure
>>> fig = create_2d_density(x, y)
>>> # Plot the data
>>> fig.show()
Example 2: Using Parameters
>>> from plotly.figure_factory import create_2d_density
>>> import numpy as np
>>> # Make data points
>>> t = np.linspace(-1,1.2,2000)
>>> x = (t**3)+(0.3*np.random.randn(2000))
>>> y = (t**6)+(0.3*np.random.randn(2000))
>>> # Create custom colorscale
>>> colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
... (1, 1, 0.2), (0.98,0.98,0.98)]
>>> # Create a figure
>>> fig = create_2d_density(x, y, colorscale=colorscale,
... hist_color='rgb(255, 237, 222)', point_size=3)
>>> # Plot the data
>>> fig.show()
"""
# validate x and y are filled with numbers only
for array in [x, y]:
if not all(isinstance(element, Number) for element in array):
raise plotly.exceptions.PlotlyError(
"All elements of your 'x' and 'y' lists must be numbers."
)
# validate x and y are the same length
if len(x) != len(y):
raise plotly.exceptions.PlotlyError(
"Both lists 'x' and 'y' must be the same length."
)
colorscale = clrs.validate_colors(colorscale, "rgb")
colorscale = make_linear_colorscale(colorscale)
# validate hist_color and point_color
hist_color = clrs.validate_colors(hist_color, "rgb")
point_color = clrs.validate_colors(point_color, "rgb")
trace1 = graph_objs.Scatter(
x=x,
y=y,
mode="markers",
name="points",
marker=dict(color=point_color[0], size=point_size, opacity=0.4),
)
trace2 = graph_objs.Histogram2dContour(
x=x,
y=y,
name="density",
ncontours=ncontours,
colorscale=colorscale,
reversescale=True,
showscale=False,
)
trace3 = graph_objs.Histogram(
x=x, name="x density", marker=dict(color=hist_color[0]), yaxis="y2"
)
trace4 = graph_objs.Histogram(
y=y, name="y density", marker=dict(color=hist_color[0]), xaxis="x2"
)
data = [trace1, trace2, trace3, trace4]
layout = graph_objs.Layout(
showlegend=False,
autosize=False,
title=title,
height=height,
width=width,
xaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
yaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
margin=dict(t=50),
hovermode="closest",
bargap=0,
xaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
yaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
)
fig = graph_objs.Figure(data=data, layout=layout)
return fig
@@ -0,0 +1,69 @@
from __future__ import absolute_import
from plotly import optional_imports
# Require that numpy exists for figure_factory
np = optional_imports.get_module("numpy")
if np is None:
raise ImportError(
"""\
The figure factory module requires the numpy package"""
)
from plotly.figure_factory._2d_density import create_2d_density
from plotly.figure_factory._annotated_heatmap import create_annotated_heatmap
from plotly.figure_factory._bullet import create_bullet
from plotly.figure_factory._candlestick import create_candlestick
from plotly.figure_factory._dendrogram import create_dendrogram
from plotly.figure_factory._distplot import create_distplot
from plotly.figure_factory._facet_grid import create_facet_grid
from plotly.figure_factory._gantt import create_gantt
from plotly.figure_factory._ohlc import create_ohlc
from plotly.figure_factory._quiver import create_quiver
from plotly.figure_factory._scatterplot import create_scatterplotmatrix
from plotly.figure_factory._streamline import create_streamline
from plotly.figure_factory._table import create_table
from plotly.figure_factory._trisurf import create_trisurf
from plotly.figure_factory._violin import create_violin
if optional_imports.get_module("pandas") is not None:
from plotly.figure_factory._county_choropleth import create_choropleth
from plotly.figure_factory._hexbin_mapbox import create_hexbin_mapbox
else:
def create_choropleth(*args, **kwargs):
raise ImportError("Please install pandas to use `create_choropleth`")
def create_hexbin_mapbox(*args, **kwargs):
raise ImportError("Please install pandas to use `create_hexbin_mapbox`")
if optional_imports.get_module("skimage") is not None:
from plotly.figure_factory._ternary_contour import create_ternary_contour
else:
def create_ternary_contour(*args, **kwargs):
raise ImportError("Please install scikit-image to use `create_ternary_contour`")
__all__ = [
"create_2d_density",
"create_annotated_heatmap",
"create_bullet",
"create_candlestick",
"create_choropleth",
"create_dendrogram",
"create_distplot",
"create_facet_grid",
"create_gantt",
"create_hexbin_mapbox",
"create_ohlc",
"create_quiver",
"create_scatterplotmatrix",
"create_streamline",
"create_table",
"create_ternary_contour",
"create_trisurf",
"create_violin",
]
@@ -0,0 +1,311 @@
from __future__ import absolute_import, division
import plotly.colors as clrs
from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
from plotly.validators.heatmap import ColorscaleValidator
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
def validate_annotated_heatmap(z, x, y, annotation_text):
"""
Annotated-heatmap-specific validations
Check that if a text matrix is supplied, it has the same
dimensions as the z matrix.
See FigureFactory.create_annotated_heatmap() for params
:raises: (PlotlyError) If z and text matrices do not have the same
dimensions.
"""
if annotation_text is not None and isinstance(annotation_text, list):
utils.validate_equal_length(z, annotation_text)
for lst in range(len(z)):
if len(z[lst]) != len(annotation_text[lst]):
raise exceptions.PlotlyError(
"z and text should have the " "same dimensions"
)
if x:
if len(x) != len(z[0]):
raise exceptions.PlotlyError(
"oops, the x list that you "
"provided does not match the "
"width of your z matrix "
)
if y:
if len(y) != len(z):
raise exceptions.PlotlyError(
"oops, the y list that you "
"provided does not match the "
"length of your z matrix "
)
def create_annotated_heatmap(
z,
x=None,
y=None,
annotation_text=None,
colorscale="Plasma",
font_colors=None,
showscale=False,
reversescale=False,
**kwargs,
):
"""
**deprecated**, use instead
:func:`plotly.express.imshow`.
Function that creates annotated heatmaps
This function adds annotations to each cell of the heatmap.
:param (list[list]|ndarray) z: z matrix to create heatmap.
:param (list) x: x axis labels.
:param (list) y: y axis labels.
:param (list[list]|ndarray) annotation_text: Text strings for
annotations. Should have the same dimensions as the z matrix. If no
text is added, the values of the z matrix are annotated. Default =
z matrix values.
:param (list|str) colorscale: heatmap colorscale.
:param (list) font_colors: List of two color strings: [min_text_color,
max_text_color] where min_text_color is applied to annotations for
heatmap values < (max_value - min_value)/2. If font_colors is not
defined, the colors are defined logically as black or white
depending on the heatmap's colorscale.
:param (bool) showscale: Display colorscale. Default = False
:param (bool) reversescale: Reverse colorscale. Default = False
:param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
These kwargs describe other attributes about the annotated Heatmap
trace such as the colorscale. For more information on valid kwargs
call help(plotly.graph_objs.Heatmap)
Example 1: Simple annotated heatmap with default configuration
>>> import plotly.figure_factory as ff
>>> z = [[0.300000, 0.00000, 0.65, 0.300000],
... [1, 0.100005, 0.45, 0.4300],
... [0.300000, 0.00000, 0.65, 0.300000],
... [1, 0.100005, 0.45, 0.00000]]
>>> fig = ff.create_annotated_heatmap(z)
>>> fig.show()
"""
# Avoiding mutables in the call signature
font_colors = font_colors if font_colors is not None else []
validate_annotated_heatmap(z, x, y, annotation_text)
# validate colorscale
colorscale_validator = ColorscaleValidator()
colorscale = colorscale_validator.validate_coerce(colorscale)
annotations = _AnnotatedHeatmap(
z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
).make_annotations()
if x or y:
trace = dict(
type="heatmap",
z=z,
x=x,
y=y,
colorscale=colorscale,
showscale=showscale,
reversescale=reversescale,
**kwargs,
)
layout = dict(
annotations=annotations,
xaxis=dict(ticks="", dtick=1, side="top", gridcolor="rgb(0, 0, 0)"),
yaxis=dict(ticks="", dtick=1, ticksuffix=" "),
)
else:
trace = dict(
type="heatmap",
z=z,
colorscale=colorscale,
showscale=showscale,
reversescale=reversescale,
**kwargs,
)
layout = dict(
annotations=annotations,
xaxis=dict(
ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False
),
yaxis=dict(ticks="", ticksuffix=" ", showticklabels=False),
)
data = [trace]
return graph_objs.Figure(data=data, layout=layout)
def to_rgb_color_list(color_str, default):
color_str = color_str.strip()
if color_str.startswith("rgb"):
return [int(v) for v in color_str.strip("rgba()").split(",")]
elif color_str.startswith("#"):
return clrs.hex_to_rgb(color_str)
else:
return default
def should_use_black_text(background_color):
return (
background_color[0] * 0.299
+ background_color[1] * 0.587
+ background_color[2] * 0.114
) > 186
class _AnnotatedHeatmap(object):
"""
Refer to TraceFactory.create_annotated_heatmap() for docstring
"""
def __init__(
self, z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
):
self.z = z
if x:
self.x = x
else:
self.x = range(len(z[0]))
if y:
self.y = y
else:
self.y = range(len(z))
if annotation_text is not None:
self.annotation_text = annotation_text
else:
self.annotation_text = self.z
self.colorscale = colorscale
self.reversescale = reversescale
self.font_colors = font_colors
if np and isinstance(self.z, np.ndarray):
self.zmin = np.amin(self.z)
self.zmax = np.amax(self.z)
else:
self.zmin = min([v for row in self.z for v in row])
self.zmax = max([v for row in self.z for v in row])
if kwargs.get("zmin", None) is not None:
self.zmin = kwargs["zmin"]
if kwargs.get("zmax", None) is not None:
self.zmax = kwargs["zmax"]
self.zmid = (self.zmax + self.zmin) / 2
if kwargs.get("zmid", None) is not None:
self.zmid = kwargs["zmid"]
def get_text_color(self):
"""
Get font color for annotations.
The annotated heatmap can feature two text colors: min_text_color and
max_text_color. The min_text_color is applied to annotations for
heatmap values < (max_value - min_value)/2. The user can define these
two colors. Otherwise the colors are defined logically as black or
white depending on the heatmap's colorscale.
:rtype (string, string) min_text_color, max_text_color: text
color for annotations for heatmap values <
(max_value - min_value)/2 and text color for annotations for
heatmap values >= (max_value - min_value)/2
"""
# Plotly colorscales ranging from a lighter shade to a darker shade
colorscales = [
"Greys",
"Greens",
"Blues",
"YIGnBu",
"YIOrRd",
"RdBu",
"Picnic",
"Jet",
"Hot",
"Blackbody",
"Earth",
"Electric",
"Viridis",
"Cividis",
]
# Plotly colorscales ranging from a darker shade to a lighter shade
colorscales_reverse = ["Reds"]
white = "#FFFFFF"
black = "#000000"
if self.font_colors:
min_text_color = self.font_colors[0]
max_text_color = self.font_colors[-1]
elif self.colorscale in colorscales and self.reversescale:
min_text_color = black
max_text_color = white
elif self.colorscale in colorscales:
min_text_color = white
max_text_color = black
elif self.colorscale in colorscales_reverse and self.reversescale:
min_text_color = white
max_text_color = black
elif self.colorscale in colorscales_reverse:
min_text_color = black
max_text_color = white
elif isinstance(self.colorscale, list):
min_col = to_rgb_color_list(self.colorscale[0][1], [255, 255, 255])
max_col = to_rgb_color_list(self.colorscale[-1][1], [255, 255, 255])
# swap min/max colors if reverse scale
if self.reversescale:
min_col, max_col = max_col, min_col
if should_use_black_text(min_col):
min_text_color = black
else:
min_text_color = white
if should_use_black_text(max_col):
max_text_color = black
else:
max_text_color = white
else:
min_text_color = black
max_text_color = black
return min_text_color, max_text_color
def make_annotations(self):
"""
Get annotations for each cell of the heatmap with graph_objs.Annotation
:rtype (list[dict]) annotations: list of annotations for each cell of
the heatmap
"""
min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
annotations = []
for n, row in enumerate(self.z):
for m, val in enumerate(row):
font_color = min_text_color if val < self.zmid else max_text_color
annotations.append(
graph_objs.layout.Annotation(
text=str(self.annotation_text[n][m]),
x=self.x[m],
y=self.y[n],
xref="x1",
yref="y1",
font=dict(color=font_color),
showarrow=False,
)
)
return annotations
@@ -0,0 +1,369 @@
from __future__ import absolute_import
import collections
import math
from plotly import exceptions, optional_imports
import plotly.colors as clrs
from plotly.figure_factory import utils
import plotly
import plotly.graph_objs as go
pd = optional_imports.get_module("pandas")
def _bullet(
df,
markers,
measures,
ranges,
subtitles,
titles,
orientation,
range_colors,
measure_colors,
horizontal_spacing,
vertical_spacing,
scatter_options,
layout_options,
):
num_of_lanes = len(df)
num_of_rows = num_of_lanes if orientation == "h" else 1
num_of_cols = 1 if orientation == "h" else num_of_lanes
if not horizontal_spacing:
horizontal_spacing = 1.0 / num_of_lanes
if not vertical_spacing:
vertical_spacing = 1.0 / num_of_lanes
fig = plotly.subplots.make_subplots(
num_of_rows,
num_of_cols,
print_grid=False,
horizontal_spacing=horizontal_spacing,
vertical_spacing=vertical_spacing,
)
# layout
fig["layout"].update(
dict(shapes=[]),
title="Bullet Chart",
height=600,
width=1000,
showlegend=False,
barmode="stack",
annotations=[],
margin=dict(l=120 if orientation == "h" else 80),
)
# update layout
fig["layout"].update(layout_options)
if orientation == "h":
width_axis = "yaxis"
length_axis = "xaxis"
else:
width_axis = "xaxis"
length_axis = "yaxis"
for key in fig["layout"]:
if "xaxis" in key or "yaxis" in key:
fig["layout"][key]["showgrid"] = False
fig["layout"][key]["zeroline"] = False
if length_axis in key:
fig["layout"][key]["tickwidth"] = 1
if width_axis in key:
fig["layout"][key]["showticklabels"] = False
fig["layout"][key]["range"] = [0, 1]
# narrow domain if 1 bar
if num_of_lanes <= 1:
fig["layout"][width_axis + "1"]["domain"] = [0.4, 0.6]
if not range_colors:
range_colors = ["rgb(200, 200, 200)", "rgb(245, 245, 245)"]
if not measure_colors:
measure_colors = ["rgb(31, 119, 180)", "rgb(176, 196, 221)"]
for row in range(num_of_lanes):
# ranges bars
for idx in range(len(df.iloc[row]["ranges"])):
inter_colors = clrs.n_colors(
range_colors[0], range_colors[1], len(df.iloc[row]["ranges"]), "rgb"
)
x = (
[sorted(df.iloc[row]["ranges"])[-1 - idx]]
if orientation == "h"
else [0]
)
y = (
[0]
if orientation == "h"
else [sorted(df.iloc[row]["ranges"])[-1 - idx]]
)
bar = go.Bar(
x=x,
y=y,
marker=dict(color=inter_colors[-1 - idx]),
name="ranges",
hoverinfo="x" if orientation == "h" else "y",
orientation=orientation,
width=2,
base=0,
xaxis="x{}".format(row + 1),
yaxis="y{}".format(row + 1),
)
fig.add_trace(bar)
# measures bars
for idx in range(len(df.iloc[row]["measures"])):
inter_colors = clrs.n_colors(
measure_colors[0],
measure_colors[1],
len(df.iloc[row]["measures"]),
"rgb",
)
x = (
[sorted(df.iloc[row]["measures"])[-1 - idx]]
if orientation == "h"
else [0.5]
)
y = (
[0.5]
if orientation == "h"
else [sorted(df.iloc[row]["measures"])[-1 - idx]]
)
bar = go.Bar(
x=x,
y=y,
marker=dict(color=inter_colors[-1 - idx]),
name="measures",
hoverinfo="x" if orientation == "h" else "y",
orientation=orientation,
width=0.4,
base=0,
xaxis="x{}".format(row + 1),
yaxis="y{}".format(row + 1),
)
fig.add_trace(bar)
# markers
x = df.iloc[row]["markers"] if orientation == "h" else [0.5]
y = [0.5] if orientation == "h" else df.iloc[row]["markers"]
markers = go.Scatter(
x=x,
y=y,
name="markers",
hoverinfo="x" if orientation == "h" else "y",
xaxis="x{}".format(row + 1),
yaxis="y{}".format(row + 1),
**scatter_options,
)
fig.add_trace(markers)
# titles and subtitles
title = df.iloc[row]["titles"]
if "subtitles" in df:
subtitle = "<br>{}".format(df.iloc[row]["subtitles"])
else:
subtitle = ""
label = "<b>{}</b>".format(title) + subtitle
annot = utils.annotation_dict_for_label(
label,
(num_of_lanes - row if orientation == "h" else row + 1),
num_of_lanes,
vertical_spacing if orientation == "h" else horizontal_spacing,
"row" if orientation == "h" else "col",
True if orientation == "h" else False,
False,
)
fig["layout"]["annotations"] += (annot,)
return fig
def create_bullet(
data,
markers=None,
measures=None,
ranges=None,
subtitles=None,
titles=None,
orientation="h",
range_colors=("rgb(200, 200, 200)", "rgb(245, 245, 245)"),
measure_colors=("rgb(31, 119, 180)", "rgb(176, 196, 221)"),
horizontal_spacing=None,
vertical_spacing=None,
scatter_options={},
**layout_options,
):
"""
**deprecated**, use instead the plotly.graph_objects trace
:class:`plotly.graph_objects.Indicator`.
:param (pd.DataFrame | list | tuple) data: either a list/tuple of
dictionaries or a pandas DataFrame.
:param (str) markers: the column name or dictionary key for the markers in
each subplot.
:param (str) measures: the column name or dictionary key for the measure
bars in each subplot. This bar usually represents the quantitative
measure of performance, usually a list of two values [a, b] and are
the blue bars in the foreground of each subplot by default.
:param (str) ranges: the column name or dictionary key for the qualitative
ranges of performance, usually a 3-item list [bad, okay, good]. They
correspond to the grey bars in the background of each chart.
:param (str) subtitles: the column name or dictionary key for the subtitle
of each subplot chart. The subplots are displayed right underneath
each title.
:param (str) titles: the column name or dictionary key for the main label
of each subplot chart.
:param (bool) orientation: if 'h', the bars are placed horizontally as
rows. If 'v' the bars are placed vertically in the chart.
:param (list) range_colors: a tuple of two colors between which all
the rectangles for the range are drawn. These rectangles are meant to
be qualitative indicators against which the marker and measure bars
are compared.
Default=('rgb(200, 200, 200)', 'rgb(245, 245, 245)')
:param (list) measure_colors: a tuple of two colors which is used to color
the thin quantitative bars in the bullet chart.
Default=('rgb(31, 119, 180)', 'rgb(176, 196, 221)')
:param (float) horizontal_spacing: see the 'horizontal_spacing' param in
plotly.tools.make_subplots. Ranges between 0 and 1.
:param (float) vertical_spacing: see the 'vertical_spacing' param in
plotly.tools.make_subplots. Ranges between 0 and 1.
:param (dict) scatter_options: describes attributes for the scatter trace
in each subplot such as name and marker size. Call
help(plotly.graph_objs.Scatter) for more information on valid params.
:param layout_options: describes attributes for the layout of the figure
such as title, height and width. Call help(plotly.graph_objs.Layout)
for more information on valid params.
Example 1: Use a Dictionary
>>> import plotly.figure_factory as ff
>>> data = [
... {"label": "revenue", "sublabel": "us$, in thousands",
... "range": [150, 225, 300], "performance": [220,270], "point": [250]},
... {"label": "Profit", "sublabel": "%", "range": [20, 25, 30],
... "performance": [21, 23], "point": [26]},
... {"label": "Order Size", "sublabel":"US$, average","range": [350, 500, 600],
... "performance": [100,320],"point": [550]},
... {"label": "New Customers", "sublabel": "count", "range": [1400, 2000, 2500],
... "performance": [1000, 1650],"point": [2100]},
... {"label": "Satisfaction", "sublabel": "out of 5","range": [3.5, 4.25, 5],
... "performance": [3.2, 4.7], "point": [4.4]}
... ]
>>> fig = ff.create_bullet(
... data, titles='label', subtitles='sublabel', markers='point',
... measures='performance', ranges='range', orientation='h',
... title='my simple bullet chart'
... )
>>> fig.show()
Example 2: Use a DataFrame with Custom Colors
>>> import plotly.figure_factory as ff
>>> import pandas as pd
>>> data = pd.read_json('https://cdn.rawgit.com/plotly/datasets/master/BulletData.json')
>>> fig = ff.create_bullet(
... data, titles='title', markers='markers', measures='measures',
... orientation='v', measure_colors=['rgb(14, 52, 75)', 'rgb(31, 141, 127)'],
... scatter_options={'marker': {'symbol': 'circle'}}, width=700)
>>> fig.show()
"""
# validate df
if not pd:
raise ImportError("'pandas' must be installed for this figure factory.")
if utils.is_sequence(data):
if not all(isinstance(item, dict) for item in data):
raise exceptions.PlotlyError(
"Every entry of the data argument list, tuple, etc must "
"be a dictionary."
)
elif not isinstance(data, pd.DataFrame):
raise exceptions.PlotlyError(
"You must input a pandas DataFrame, or a list of dictionaries."
)
# make DataFrame from data with correct column headers
col_names = ["titles", "subtitle", "markers", "measures", "ranges"]
if utils.is_sequence(data):
df = pd.DataFrame(
[
[d[titles] for d in data] if titles else [""] * len(data),
[d[subtitles] for d in data] if subtitles else [""] * len(data),
[d[markers] for d in data] if markers else [[]] * len(data),
[d[measures] for d in data] if measures else [[]] * len(data),
[d[ranges] for d in data] if ranges else [[]] * len(data),
],
index=col_names,
)
elif isinstance(data, pd.DataFrame):
df = pd.DataFrame(
[
data[titles].tolist() if titles else [""] * len(data),
data[subtitles].tolist() if subtitles else [""] * len(data),
data[markers].tolist() if markers else [[]] * len(data),
data[measures].tolist() if measures else [[]] * len(data),
data[ranges].tolist() if ranges else [[]] * len(data),
],
index=col_names,
)
df = pd.DataFrame.transpose(df)
# make sure ranges, measures, 'markers' are not NAN or NONE
for needed_key in ["ranges", "measures", "markers"]:
for idx, r in enumerate(df[needed_key]):
try:
r_is_nan = math.isnan(r)
if r_is_nan or r is None:
df[needed_key][idx] = []
except TypeError:
pass
# validate custom colors
for colors_list in [range_colors, measure_colors]:
if colors_list:
if len(colors_list) != 2:
raise exceptions.PlotlyError(
"Both 'range_colors' or 'measure_colors' must be a list "
"of two valid colors."
)
clrs.validate_colors(colors_list)
colors_list = clrs.convert_colors_to_same_type(colors_list, "rgb")[0]
# default scatter options
default_scatter = {
"marker": {"size": 12, "symbol": "diamond-tall", "color": "rgb(0, 0, 0)"}
}
if scatter_options == {}:
scatter_options.update(default_scatter)
else:
# add default options to scatter_options if they are not present
for k in default_scatter["marker"]:
if k not in scatter_options["marker"]:
scatter_options["marker"][k] = default_scatter["marker"][k]
fig = _bullet(
df,
markers,
measures,
ranges,
subtitles,
titles,
orientation,
range_colors,
measure_colors,
horizontal_spacing,
vertical_spacing,
scatter_options,
layout_options,
)
return fig
@@ -0,0 +1,279 @@
from __future__ import absolute_import
from plotly.figure_factory import utils
from plotly.figure_factory._ohlc import (
_DEFAULT_INCREASING_COLOR,
_DEFAULT_DECREASING_COLOR,
validate_ohlc,
)
from plotly.graph_objs import graph_objs
def make_increasing_candle(open, high, low, close, dates, **kwargs):
"""
Makes boxplot trace for increasing candlesticks
_make_increasing_candle() and _make_decreasing_candle separate the
increasing traces from the decreasing traces so kwargs (such as
color) can be passed separately to increasing or decreasing traces
when direction is set to 'increasing' or 'decreasing' in
FigureFactory.create_candlestick()
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param kwargs: kwargs to be passed to increasing trace via
plotly.graph_objs.Scatter.
:rtype (list) candle_incr_data: list of the box trace for
increasing candlesticks.
"""
increase_x, increase_y = _Candlestick(
open, high, low, close, dates, **kwargs
).get_candle_increase()
if "line" in kwargs:
kwargs.setdefault("fillcolor", kwargs["line"]["color"])
else:
kwargs.setdefault("fillcolor", _DEFAULT_INCREASING_COLOR)
if "name" in kwargs:
kwargs.setdefault("showlegend", True)
else:
kwargs.setdefault("showlegend", False)
kwargs.setdefault("name", "Increasing")
kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR))
candle_incr_data = dict(
type="box",
x=increase_x,
y=increase_y,
whiskerwidth=0,
boxpoints=False,
**kwargs,
)
return [candle_incr_data]
def make_decreasing_candle(open, high, low, close, dates, **kwargs):
"""
Makes boxplot trace for decreasing candlesticks
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param kwargs: kwargs to be passed to decreasing trace via
plotly.graph_objs.Scatter.
:rtype (list) candle_decr_data: list of the box trace for
decreasing candlesticks.
"""
decrease_x, decrease_y = _Candlestick(
open, high, low, close, dates, **kwargs
).get_candle_decrease()
if "line" in kwargs:
kwargs.setdefault("fillcolor", kwargs["line"]["color"])
else:
kwargs.setdefault("fillcolor", _DEFAULT_DECREASING_COLOR)
kwargs.setdefault("showlegend", False)
kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR))
kwargs.setdefault("name", "Decreasing")
candle_decr_data = dict(
type="box",
x=decrease_x,
y=decrease_y,
whiskerwidth=0,
boxpoints=False,
**kwargs,
)
return [candle_decr_data]
def create_candlestick(open, high, low, close, dates=None, direction="both", **kwargs):
"""
**deprecated**, use instead the plotly.graph_objects trace
:class:`plotly.graph_objects.Candlestick`
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param (string) direction: direction can be 'increasing', 'decreasing',
or 'both'. When the direction is 'increasing', the returned figure
consists of all candlesticks where the close value is greater than
the corresponding open value, and when the direction is
'decreasing', the returned figure consists of all candlesticks
where the close value is less than or equal to the corresponding
open value. When the direction is 'both', both increasing and
decreasing candlesticks are returned. Default: 'both'
:param kwargs: kwargs passed through plotly.graph_objs.Scatter.
These kwargs describe other attributes about the ohlc Scatter trace
such as the color or the legend name. For more information on valid
kwargs call help(plotly.graph_objs.Scatter)
:rtype (dict): returns a representation of candlestick chart figure.
Example 1: Simple candlestick chart from a Pandas DataFrame
>>> from plotly.figure_factory import create_candlestick
>>> from datetime import datetime
>>> import pandas as pd
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
>>> fig = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
... dates=df.index)
>>> fig.show()
Example 2: Customize the candlestick colors
>>> from plotly.figure_factory import create_candlestick
>>> from plotly.graph_objs import Line, Marker
>>> from datetime import datetime
>>> import pandas as pd
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
>>> # Make increasing candlesticks and customize their color and name
>>> fig_increasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
... dates=df.index,
... direction='increasing', name='AAPL',
... marker=Marker(color='rgb(150, 200, 250)'),
... line=Line(color='rgb(150, 200, 250)'))
>>> # Make decreasing candlesticks and customize their color and name
>>> fig_decreasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
... dates=df.index,
... direction='decreasing',
... marker=Marker(color='rgb(128, 128, 128)'),
... line=Line(color='rgb(128, 128, 128)'))
>>> # Initialize the figure
>>> fig = fig_increasing
>>> # Add decreasing data with .extend()
>>> fig.add_trace(fig_decreasing['data']) # doctest: +SKIP
>>> fig.show()
Example 3: Candlestick chart with datetime objects
>>> from plotly.figure_factory import create_candlestick
>>> from datetime import datetime
>>> # Add data
>>> open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
>>> high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
>>> low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
>>> close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
>>> dates = [datetime(year=2013, month=10, day=10),
... datetime(year=2013, month=11, day=10),
... datetime(year=2013, month=12, day=10),
... datetime(year=2014, month=1, day=10),
... datetime(year=2014, month=2, day=10)]
>>> # Create ohlc
>>> fig = create_candlestick(open_data, high_data,
... low_data, close_data, dates=dates)
>>> fig.show()
"""
if dates is not None:
utils.validate_equal_length(open, high, low, close, dates)
else:
utils.validate_equal_length(open, high, low, close)
validate_ohlc(open, high, low, close, direction, **kwargs)
if direction == "increasing":
candle_incr_data = make_increasing_candle(
open, high, low, close, dates, **kwargs
)
data = candle_incr_data
elif direction == "decreasing":
candle_decr_data = make_decreasing_candle(
open, high, low, close, dates, **kwargs
)
data = candle_decr_data
else:
candle_incr_data = make_increasing_candle(
open, high, low, close, dates, **kwargs
)
candle_decr_data = make_decreasing_candle(
open, high, low, close, dates, **kwargs
)
data = candle_incr_data + candle_decr_data
layout = graph_objs.Layout()
return graph_objs.Figure(data=data, layout=layout)
class _Candlestick(object):
"""
Refer to FigureFactory.create_candlestick() for docstring.
"""
def __init__(self, open, high, low, close, dates, **kwargs):
self.open = open
self.high = high
self.low = low
self.close = close
if dates is not None:
self.x = dates
else:
self.x = [x for x in range(len(self.open))]
self.get_candle_increase()
def get_candle_increase(self):
"""
Separate increasing data from decreasing data.
The data is increasing when close value > open value
and decreasing when the close value <= open value.
"""
increase_y = []
increase_x = []
for index in range(len(self.open)):
if self.close[index] > self.open[index]:
increase_y.append(self.low[index])
increase_y.append(self.open[index])
increase_y.append(self.close[index])
increase_y.append(self.close[index])
increase_y.append(self.close[index])
increase_y.append(self.high[index])
increase_x.append(self.x[index])
increase_x = [[x, x, x, x, x, x] for x in increase_x]
increase_x = utils.flatten(increase_x)
return increase_x, increase_y
def get_candle_decrease(self):
"""
Separate increasing data from decreasing data.
The data is increasing when close value > open value
and decreasing when the close value <= open value.
"""
decrease_y = []
decrease_x = []
for index in range(len(self.open)):
if self.close[index] <= self.open[index]:
decrease_y.append(self.low[index])
decrease_y.append(self.open[index])
decrease_y.append(self.close[index])
decrease_y.append(self.close[index])
decrease_y.append(self.close[index])
decrease_y.append(self.high[index])
decrease_x.append(self.x[index])
decrease_x = [[x, x, x, x, x, x] for x in decrease_x]
decrease_x = utils.flatten(decrease_x)
return decrease_x, decrease_y
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,399 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from collections import OrderedDict
from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
scp = optional_imports.get_module("scipy")
sch = optional_imports.get_module("scipy.cluster.hierarchy")
scs = optional_imports.get_module("scipy.spatial")
def create_dendrogram(
X,
orientation="bottom",
labels=None,
colorscale=None,
distfun=None,
linkagefun=lambda x: sch.linkage(x, "complete"),
hovertext=None,
color_threshold=None,
):
"""
Function that returns a dendrogram Plotly figure object. This is a thin
wrapper around scipy.cluster.hierarchy.dendrogram.
See also https://dash.plot.ly/dash-bio/clustergram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (str) orientation: 'top', 'right', 'bottom', or 'left'
:param (list) labels: List of axis category labels(observation labels)
:param (list) colorscale: Optional colorscale for the dendrogram tree.
Requires 8 colors to be specified, the 7th of
which is ignored. With scipy>=1.5.0, the 2nd, 3rd
and 6th are used twice as often as the others.
Given a shorter list, the missing values are
replaced with defaults and with a longer list the
extra values are ignored.
:param (function) distfun: Function to compute the pairwise distance from
the observations
:param (function) linkagefun: Function to compute the linkage matrix from
the pairwise distances
:param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
clusters
:param (double) color_threshold: Value at which the separation of clusters will be made
Example 1: Simple bottom oriented dendrogram
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> X = np.random.rand(10,10)
>>> fig = create_dendrogram(X)
>>> fig.show()
Example 2: Dendrogram to put on the left of the heatmap
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> X = np.random.rand(5,5)
>>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
>>> dendro = create_dendrogram(X, orientation='right', labels=names)
>>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
>>> dendro.show()
Example 3: Dendrogram with Pandas
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> import pandas as pd
>>> Index= ['A','B','C','D','E','F','G','H','I','J']
>>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
>>> fig = create_dendrogram(df, labels=Index)
>>> fig.show()
"""
if not scp or not scs or not sch:
raise ImportError(
"FigureFactory.create_dendrogram requires scipy, \
scipy.spatial and scipy.hierarchy"
)
s = X.shape
if len(s) != 2:
exceptions.PlotlyError("X should be 2-dimensional array.")
if distfun is None:
distfun = scs.distance.pdist
dendrogram = _Dendrogram(
X,
orientation,
labels,
colorscale,
distfun=distfun,
linkagefun=linkagefun,
hovertext=hovertext,
color_threshold=color_threshold,
)
return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
class _Dendrogram(object):
"""Refer to FigureFactory.create_dendrogram() for docstring."""
def __init__(
self,
X,
orientation="bottom",
labels=None,
colorscale=None,
width=np.inf,
height=np.inf,
xaxis="xaxis",
yaxis="yaxis",
distfun=None,
linkagefun=lambda x: sch.linkage(x, "complete"),
hovertext=None,
color_threshold=None,
):
self.orientation = orientation
self.labels = labels
self.xaxis = xaxis
self.yaxis = yaxis
self.data = []
self.leaves = []
self.sign = {self.xaxis: 1, self.yaxis: 1}
self.layout = {self.xaxis: {}, self.yaxis: {}}
if self.orientation in ["left", "bottom"]:
self.sign[self.xaxis] = 1
else:
self.sign[self.xaxis] = -1
if self.orientation in ["right", "bottom"]:
self.sign[self.yaxis] = 1
else:
self.sign[self.yaxis] = -1
if distfun is None:
distfun = scs.distance.pdist
(dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces(
X, colorscale, distfun, linkagefun, hovertext, color_threshold
)
self.labels = ordered_labels
self.leaves = leaves
yvals_flat = yvals.flatten()
xvals_flat = xvals.flatten()
self.zero_vals = []
for i in range(len(yvals_flat)):
if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
self.zero_vals.append(xvals_flat[i])
if len(self.zero_vals) > len(yvals) + 1:
# If the length of zero_vals is larger than the length of yvals,
# it means that there are wrong vals because of the identicial samples.
# Three and more identicial samples will make the yvals of spliting
# center into 0 and it will accidentally take it as leaves.
l_border = int(min(self.zero_vals))
r_border = int(max(self.zero_vals))
correct_leaves_pos = range(
l_border, r_border + 1, int((r_border - l_border) / len(yvals))
)
# Regenerating the leaves pos from the self.zero_vals with equally intervals.
self.zero_vals = [v for v in correct_leaves_pos]
self.zero_vals.sort()
self.layout = self.set_figure_layout(width, height)
self.data = dd_traces
def get_color_dict(self, colorscale):
"""
Returns colorscale used for dendrogram tree clusters.
:param (list) colorscale: Colors to use for the plot in rgb format.
:rtype (dict): A dict of default colors mapped to the user colorscale.
"""
# These are the color codes returned for dendrograms
# We're replacing them with nicer colors
# This list is the colors that can be used by dendrogram, which were
# determined as the combination of the default above_threshold_color and
# the default color palette (see scipy/cluster/hierarchy.py)
d = {
"r": "red",
"g": "green",
"b": "blue",
"c": "cyan",
"m": "magenta",
"y": "yellow",
"k": "black",
# TODO: 'w' doesn't seem to be in the default color
# palette in scipy/cluster/hierarchy.py
"w": "white",
}
default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
if colorscale is None:
rgb_colorscale = [
"rgb(0,116,217)", # blue
"rgb(35,205,205)", # cyan
"rgb(61,153,112)", # green
"rgb(40,35,35)", # black
"rgb(133,20,75)", # magenta
"rgb(255,65,54)", # red
"rgb(255,255,255)", # white
"rgb(255,220,0)", # yellow
]
else:
rgb_colorscale = colorscale
for i in range(len(default_colors.keys())):
k = list(default_colors.keys())[i] # PY3 won't index keys
if i < len(rgb_colorscale):
default_colors[k] = rgb_colorscale[i]
# add support for cyclic format colors as introduced in scipy===1.5.0
# before this, the colors were named 'r', 'b', 'y' etc., now they are
# named 'C0', 'C1', etc. To keep the colors consistent regardless of the
# scipy version, we try as much as possible to map the new colors to the
# old colors
# this mapping was found by inpecting scipy/cluster/hierarchy.py (see
# comment above).
new_old_color_map = [
("C0", "b"),
("C1", "g"),
("C2", "r"),
("C3", "c"),
("C4", "m"),
("C5", "y"),
("C6", "k"),
("C7", "g"),
("C8", "r"),
("C9", "c"),
]
for nc, oc in new_old_color_map:
try:
default_colors[nc] = default_colors[oc]
except KeyError:
# it could happen that the old color isn't found (if a custom
# colorscale was specified), in this case we set it to an
# arbitrary default.
default_colors[n] = "rgb(0,116,217)"
return default_colors
def set_axis_layout(self, axis_key):
"""
Sets and returns default axis object for dendrogram figure.
:param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
:rtype (dict): An axis_key dictionary with set parameters.
"""
axis_defaults = {
"type": "linear",
"ticks": "outside",
"mirror": "allticks",
"rangemode": "tozero",
"showticklabels": True,
"zeroline": False,
"showgrid": False,
"showline": True,
}
if len(self.labels) != 0:
axis_key_labels = self.xaxis
if self.orientation in ["left", "right"]:
axis_key_labels = self.yaxis
if axis_key_labels not in self.layout:
self.layout[axis_key_labels] = {}
self.layout[axis_key_labels]["tickvals"] = [
zv * self.sign[axis_key] for zv in self.zero_vals
]
self.layout[axis_key_labels]["ticktext"] = self.labels
self.layout[axis_key_labels]["tickmode"] = "array"
self.layout[axis_key].update(axis_defaults)
return self.layout[axis_key]
def set_figure_layout(self, width, height):
"""
Sets and returns default layout object for dendrogram figure.
"""
self.layout.update(
{
"showlegend": False,
"autosize": False,
"hovermode": "closest",
"width": width,
"height": height,
}
)
self.set_axis_layout(self.xaxis)
self.set_axis_layout(self.yaxis)
return self.layout
def get_dendrogram_traces(
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
):
"""
Calculates all the elements needed for plotting a dendrogram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (list) colorscale: Color scale for dendrogram tree clusters
:param (function) distfun: Function to compute the pairwise distance
from the observations
:param (function) linkagefun: Function to compute the linkage matrix
from the pairwise distances
:param (list) hovertext: List of hovertext for constituent traces of dendrogram
:rtype (tuple): Contains all the traces in the following order:
(a) trace_list: List of Plotly trace objects for dendrogram tree
(b) icoord: All X points of the dendrogram tree as array of arrays
with length 4
(c) dcoord: All Y points of the dendrogram tree as array of arrays
with length 4
(d) ordered_labels: leaf labels in the order they are going to
appear on the plot
(e) P['leaves']: left-to-right traversal of the leaves
"""
d = distfun(X)
Z = linkagefun(d)
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
)
icoord = scp.array(P["icoord"])
dcoord = scp.array(P["dcoord"])
ordered_labels = scp.array(P["ivl"])
color_list = scp.array(P["color_list"])
colors = self.get_color_dict(colorscale)
trace_list = []
for i in range(len(icoord)):
# xs and ys are arrays of 4 points that make up the '∩' shapes
# of the dendrogram tree
if self.orientation in ["top", "bottom"]:
xs = icoord[i]
else:
xs = dcoord[i]
if self.orientation in ["top", "bottom"]:
ys = dcoord[i]
else:
ys = icoord[i]
color_key = color_list[i]
hovertext_label = None
if hovertext:
hovertext_label = hovertext[i]
trace = dict(
type="scatter",
x=np.multiply(self.sign[self.xaxis], xs),
y=np.multiply(self.sign[self.yaxis], ys),
mode="lines",
marker=dict(color=colors[color_key]),
text=hovertext_label,
hoverinfo="text",
)
try:
x_index = int(self.xaxis[-1])
except ValueError:
x_index = ""
try:
y_index = int(self.yaxis[-1])
except ValueError:
y_index = ""
trace["xaxis"] = "x" + x_index
trace["yaxis"] = "y" + y_index
trace_list.append(trace)
return trace_list, icoord, dcoord, ordered_labels, P["leaves"]
@@ -0,0 +1,449 @@
from __future__ import absolute_import
from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
pd = optional_imports.get_module("pandas")
scipy = optional_imports.get_module("scipy")
scipy_stats = optional_imports.get_module("scipy.stats")
DEFAULT_HISTNORM = "probability density"
ALTERNATIVE_HISTNORM = "probability"
def validate_distplot(hist_data, curve_type):
"""
Distplot-specific validations
:raises: (PlotlyError) If hist_data is not a list of lists
:raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
'normal').
"""
hist_data_types = (list,)
if np:
hist_data_types += (np.ndarray,)
if pd:
hist_data_types += (pd.core.series.Series,)
if not isinstance(hist_data[0], hist_data_types):
raise exceptions.PlotlyError(
"Oops, this function was written "
"to handle multiple datasets, if "
"you want to plot just one, make "
"sure your hist_data variable is "
"still a list of lists, i.e. x = "
"[1, 2, 3] -> x = [[1, 2, 3]]"
)
curve_opts = ("kde", "normal")
if curve_type not in curve_opts:
raise exceptions.PlotlyError(
"curve_type must be defined as " "'kde' or 'normal'"
)
if not scipy:
raise ImportError("FigureFactory.create_distplot requires scipy")
def create_distplot(
hist_data,
group_labels,
bin_size=1.0,
curve_type="kde",
colors=None,
rug_text=None,
histnorm=DEFAULT_HISTNORM,
show_hist=True,
show_curve=True,
show_rug=True,
):
"""
Function that creates a distplot similar to seaborn.distplot;
**this function is deprecated**, use instead :mod:`plotly.express`
functions, for example
>>> import plotly.express as px
>>> tips = px.data.tips()
>>> fig = px.histogram(tips, x="total_bill", y="tip", color="sex", marginal="rug",
... hover_data=tips.columns)
>>> fig.show()
The distplot can be composed of all or any combination of the following
3 components: (1) histogram, (2) curve: (a) kernel density estimation
or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
(from multiple datasets) can be created in the same plot.
:param (list[list]) hist_data: Use list of lists to plot multiple data
sets on the same plot.
:param (list[str]) group_labels: Names for each data set.
:param (list[float]|float) bin_size: Size of histogram bins.
Default = 1.
:param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
:param (str) histnorm: 'probability density' or 'probability'
Default = 'probability density'
:param (bool) show_hist: Add histogram to distplot? Default = True
:param (bool) show_curve: Add curve to distplot? Default = True
:param (bool) show_rug: Add rug to distplot? Default = True
:param (list[str]) colors: Colors for traces.
:param (list[list]) rug_text: Hovertext values for rug_plot,
:return (dict): Representation of a distplot figure.
Example 1: Simple distplot of 1 data set
>>> from plotly.figure_factory import create_distplot
>>> hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
... 3.5, 4.1, 4.4, 4.5, 4.5,
... 5.0, 5.0, 5.2, 5.5, 5.5,
... 5.5, 5.5, 5.5, 6.1, 7.0]]
>>> group_labels = ['distplot example']
>>> fig = create_distplot(hist_data, group_labels)
>>> fig.show()
Example 2: Two data sets and added rug text
>>> from plotly.figure_factory import create_distplot
>>> # Add histogram data
>>> hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
... -0.9, -0.07, 1.95, 0.9, -0.2,
... -0.5, 0.3, 0.4, -0.37, 0.6]
>>> hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
... 1.0, 0.8, 1.7, 0.5, 0.8,
... -0.3, 1.2, 0.56, 0.3, 2.2]
>>> # Group data together
>>> hist_data = [hist1_x, hist2_x]
>>> group_labels = ['2012', '2013']
>>> # Add text
>>> rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
... 'f1', 'g1', 'h1', 'i1', 'j1',
... 'k1', 'l1', 'm1', 'n1', 'o1']
>>> rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
... 'f2', 'g2', 'h2', 'i2', 'j2',
... 'k2', 'l2', 'm2', 'n2', 'o2']
>>> # Group text together
>>> rug_text_all = [rug_text_1, rug_text_2]
>>> # Create distplot
>>> fig = create_distplot(
... hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)
>>> # Add title
>>> fig.update_layout(title='Dist Plot') # doctest: +SKIP
>>> fig.show()
Example 3: Plot with normal curve and hide rug plot
>>> from plotly.figure_factory import create_distplot
>>> import numpy as np
>>> x1 = np.random.randn(190)
>>> x2 = np.random.randn(200)+1
>>> x3 = np.random.randn(200)-1
>>> x4 = np.random.randn(210)+2
>>> hist_data = [x1, x2, x3, x4]
>>> group_labels = ['2012', '2013', '2014', '2015']
>>> fig = create_distplot(
... hist_data, group_labels, curve_type='normal',
... show_rug=False, bin_size=.4)
Example 4: Distplot with Pandas
>>> from plotly.figure_factory import create_distplot
>>> import numpy as np
>>> import pandas as pd
>>> df = pd.DataFrame({'2012': np.random.randn(200),
... '2013': np.random.randn(200)+1})
>>> fig = create_distplot([df[c] for c in df.columns], df.columns)
>>> fig.show()
"""
if colors is None:
colors = []
if rug_text is None:
rug_text = []
validate_distplot(hist_data, curve_type)
utils.validate_equal_length(hist_data, group_labels)
if isinstance(bin_size, (float, int)):
bin_size = [bin_size] * len(hist_data)
data = []
if show_hist:
hist = _Distplot(
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
).make_hist()
data.append(hist)
if show_curve:
if curve_type == "normal":
curve = _Distplot(
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
).make_normal()
else:
curve = _Distplot(
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
).make_kde()
data.append(curve)
if show_rug:
rug = _Distplot(
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
).make_rug()
data.append(rug)
layout = graph_objs.Layout(
barmode="overlay",
hovermode="closest",
legend=dict(traceorder="reversed"),
xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
yaxis1=dict(domain=[0.35, 1], anchor="free", position=0.0),
yaxis2=dict(domain=[0, 0.25], anchor="x1", dtick=1, showticklabels=False),
)
else:
layout = graph_objs.Layout(
barmode="overlay",
hovermode="closest",
legend=dict(traceorder="reversed"),
xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
yaxis1=dict(domain=[0.0, 1], anchor="free", position=0.0),
)
data = sum(data, [])
return graph_objs.Figure(data=data, layout=layout)
class _Distplot(object):
"""
Refer to TraceFactory.create_distplot() for docstring
"""
def __init__(
self,
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
):
self.hist_data = hist_data
self.histnorm = histnorm
self.group_labels = group_labels
self.bin_size = bin_size
self.show_hist = show_hist
self.show_curve = show_curve
self.trace_number = len(hist_data)
if rug_text:
self.rug_text = rug_text
else:
self.rug_text = [None] * self.trace_number
self.start = []
self.end = []
if colors:
self.colors = colors
else:
self.colors = [
"rgb(31, 119, 180)",
"rgb(255, 127, 14)",
"rgb(44, 160, 44)",
"rgb(214, 39, 40)",
"rgb(148, 103, 189)",
"rgb(140, 86, 75)",
"rgb(227, 119, 194)",
"rgb(127, 127, 127)",
"rgb(188, 189, 34)",
"rgb(23, 190, 207)",
]
self.curve_x = [None] * self.trace_number
self.curve_y = [None] * self.trace_number
for trace in self.hist_data:
self.start.append(min(trace) * 1.0)
self.end.append(max(trace) * 1.0)
def make_hist(self):
"""
Makes the histogram(s) for FigureFactory.create_distplot().
:rtype (list) hist: list of histogram representations
"""
hist = [None] * self.trace_number
for index in range(self.trace_number):
hist[index] = dict(
type="histogram",
x=self.hist_data[index],
xaxis="x1",
yaxis="y1",
histnorm=self.histnorm,
name=self.group_labels[index],
legendgroup=self.group_labels[index],
marker=dict(color=self.colors[index % len(self.colors)]),
autobinx=False,
xbins=dict(
start=self.start[index],
end=self.end[index],
size=self.bin_size[index],
),
opacity=0.7,
)
return hist
def make_kde(self):
"""
Makes the kernel density estimation(s) for create_distplot().
This is called when curve_type = 'kde' in create_distplot().
:rtype (list) curve: list of kde representations
"""
curve = [None] * self.trace_number
for index in range(self.trace_number):
self.curve_x[index] = [
self.start[index] + x * (self.end[index] - self.start[index]) / 500
for x in range(500)
]
self.curve_y[index] = scipy_stats.gaussian_kde(self.hist_data[index])(
self.curve_x[index]
)
if self.histnorm == ALTERNATIVE_HISTNORM:
self.curve_y[index] *= self.bin_size[index]
for index in range(self.trace_number):
curve[index] = dict(
type="scatter",
x=self.curve_x[index],
y=self.curve_y[index],
xaxis="x1",
yaxis="y1",
mode="lines",
name=self.group_labels[index],
legendgroup=self.group_labels[index],
showlegend=False if self.show_hist else True,
marker=dict(color=self.colors[index % len(self.colors)]),
)
return curve
def make_normal(self):
"""
Makes the normal curve(s) for create_distplot().
This is called when curve_type = 'normal' in create_distplot().
:rtype (list) curve: list of normal curve representations
"""
curve = [None] * self.trace_number
mean = [None] * self.trace_number
sd = [None] * self.trace_number
for index in range(self.trace_number):
mean[index], sd[index] = scipy_stats.norm.fit(self.hist_data[index])
self.curve_x[index] = [
self.start[index] + x * (self.end[index] - self.start[index]) / 500
for x in range(500)
]
self.curve_y[index] = scipy_stats.norm.pdf(
self.curve_x[index], loc=mean[index], scale=sd[index]
)
if self.histnorm == ALTERNATIVE_HISTNORM:
self.curve_y[index] *= self.bin_size[index]
for index in range(self.trace_number):
curve[index] = dict(
type="scatter",
x=self.curve_x[index],
y=self.curve_y[index],
xaxis="x1",
yaxis="y1",
mode="lines",
name=self.group_labels[index],
legendgroup=self.group_labels[index],
showlegend=False if self.show_hist else True,
marker=dict(color=self.colors[index % len(self.colors)]),
)
return curve
def make_rug(self):
"""
Makes the rug plot(s) for create_distplot().
:rtype (list) rug: list of rug plot representations
"""
rug = [None] * self.trace_number
for index in range(self.trace_number):
rug[index] = dict(
type="scatter",
x=self.hist_data[index],
y=([self.group_labels[index]] * len(self.hist_data[index])),
xaxis="x1",
yaxis="y2",
mode="markers",
name=self.group_labels[index],
legendgroup=self.group_labels[index],
showlegend=(False if self.show_hist or self.show_curve else True),
text=self.rug_text[index],
marker=dict(
color=self.colors[index % len(self.colors)], symbol="line-ns-open"
),
)
return rug
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,492 @@
from plotly.express._core import build_dataframe
from plotly.express._doc import make_docstring
from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox
import numpy as np
import pandas as pd
def _project_latlon_to_wgs84(lat, lon):
"""
Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map
"""
x = lon * np.pi / 180
y = np.arctanh(np.sin(lat * np.pi / 180))
return x, y
def _project_wgs84_to_latlon(x, y):
"""
Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map
"""
lon = x * 180 / np.pi
lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi
return lat, lon
def _getBoundsZoomLevel(lon_min, lon_max, lat_min, lat_max, mapDim):
"""
Get the mapbox zoom level given bounds and a figure dimension
Source: https://stackoverflow.com/questions/6048975/google-maps-v3-how-to-calculate-the-zoom-level-for-a-given-bounds
"""
scale = (
2 # adjustment to reflect MapBox base tiles are 512x512 vs. Google's 256x256
)
WORLD_DIM = {"height": 256 * scale, "width": 256 * scale}
ZOOM_MAX = 18
def latRad(lat):
sin = np.sin(lat * np.pi / 180)
radX2 = np.log((1 + sin) / (1 - sin)) / 2
return max(min(radX2, np.pi), -np.pi) / 2
def zoom(mapPx, worldPx, fraction):
return 0.95 * np.log(mapPx / worldPx / fraction) / np.log(2)
latFraction = (latRad(lat_max) - latRad(lat_min)) / np.pi
lngDiff = lon_max - lon_min
lngFraction = ((lngDiff + 360) if lngDiff < 0 else lngDiff) / 360
latZoom = zoom(mapDim["height"], WORLD_DIM["height"], latFraction)
lngZoom = zoom(mapDim["width"], WORLD_DIM["width"], lngFraction)
return min(latZoom, lngZoom, ZOOM_MAX)
def _compute_hexbin(x, y, x_range, y_range, color, nx, agg_func, min_count):
"""
Computes the aggregation at hexagonal bin level.
Also defines the coordinates of the hexagons for plotting.
The binning is inspired by matplotlib's implementation.
Parameters
----------
x : np.ndarray
Array of x values (shape N)
y : np.ndarray
Array of y values (shape N)
x_range : np.ndarray
Min and max x (shape 2)
y_range : np.ndarray
Min and max y (shape 2)
color : np.ndarray
Metric to aggregate at hexagon level (shape N)
nx : int
Number of hexagons horizontally
agg_func : function
Numpy compatible aggregator, this function must take a one-dimensional
np.ndarray as input and output a scalar
min_count : int
Minimum number of points in the hexagon for the hexagon to be displayed
Returns
-------
np.ndarray
X coordinates of each hexagon (shape M x 6)
np.ndarray
Y coordinates of each hexagon (shape M x 6)
np.ndarray
Centers of the hexagons (shape M x 2)
np.ndarray
Aggregated value in each hexagon (shape M)
"""
xmin = x_range.min()
xmax = x_range.max()
ymin = y_range.min()
ymax = y_range.max()
# In the x-direction, the hexagons exactly cover the region from
# xmin to xmax. Need some padding to avoid roundoff errors.
padding = 1.0e-9 * (xmax - xmin)
xmin -= padding
xmax += padding
Dx = xmax - xmin
Dy = ymax - ymin
if Dx == 0 and Dy > 0:
dx = Dy / nx
elif Dx == 0 and Dy == 0:
dx, _ = _project_latlon_to_wgs84(1, 1)
else:
dx = Dx / nx
dy = dx * np.sqrt(3)
ny = np.ceil(Dy / dy).astype(int)
# Center the hexagons vertically since we only want regular hexagons
ymin -= (ymin + dy * ny - ymax) / 2
x = (x - xmin) / dx
y = (y - ymin) / dy
ix1 = np.round(x).astype(int)
iy1 = np.round(y).astype(int)
ix2 = np.floor(x).astype(int)
iy2 = np.floor(y).astype(int)
nx1 = nx + 1
ny1 = ny + 1
nx2 = nx
ny2 = ny
n = nx1 * ny1 + nx2 * ny2
d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
bdist = d1 < d2
if color is None:
lattice1 = np.zeros((nx1, ny1))
lattice2 = np.zeros((nx2, ny2))
c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
if min_count is not None:
lattice1[lattice1 < min_count] = np.nan
lattice2[lattice2 < min_count] = np.nan
accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
good_idxs = ~np.isnan(accum)
else:
if min_count is None:
min_count = 1
# create accumulation arrays
lattice1 = np.empty((nx1, ny1), dtype=object)
for i in range(nx1):
for j in range(ny1):
lattice1[i, j] = []
lattice2 = np.empty((nx2, ny2), dtype=object)
for i in range(nx2):
for j in range(ny2):
lattice2[i, j] = []
for i in range(len(x)):
if bdist[i]:
if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
lattice1[ix1[i], iy1[i]].append(color[i])
else:
if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
lattice2[ix2[i], iy2[i]].append(color[i])
for i in range(nx1):
for j in range(ny1):
vals = lattice1[i, j]
if len(vals) >= min_count:
lattice1[i, j] = agg_func(vals)
else:
lattice1[i, j] = np.nan
for i in range(nx2):
for j in range(ny2):
vals = lattice2[i, j]
if len(vals) >= min_count:
lattice2[i, j] = agg_func(vals)
else:
lattice2[i, j] = np.nan
accum = np.hstack(
(lattice1.astype(float).ravel(), lattice2.astype(float).ravel())
)
good_idxs = ~np.isnan(accum)
agreggated_value = accum[good_idxs]
centers = np.zeros((n, 2), float)
centers[: nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
centers[: nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
centers[nx1 * ny1 :, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
centers[nx1 * ny1 :, 1] = np.tile(np.arange(ny2), nx2) + 0.5
centers[:, 0] *= dx
centers[:, 1] *= dy
centers[:, 0] += xmin
centers[:, 1] += ymin
centers = centers[good_idxs]
# Define normalised regular hexagon coordinates
hx = [0, 0.5, 0.5, 0, -0.5, -0.5]
hy = [
-0.5 / np.cos(np.pi / 6),
-0.5 * np.tan(np.pi / 6),
0.5 * np.tan(np.pi / 6),
0.5 / np.cos(np.pi / 6),
0.5 * np.tan(np.pi / 6),
-0.5 * np.tan(np.pi / 6),
]
# Number of hexagons needed
m = len(centers)
# Coordinates for all hexagonal patches
hxs = np.array([hx] * m) * dx + np.vstack(centers[:, 0])
hys = np.array([hy] * m) * dy / np.sqrt(3) + np.vstack(centers[:, 1])
return hxs, hys, centers, agreggated_value
def _compute_wgs84_hexbin(
lat=None,
lon=None,
lat_range=None,
lon_range=None,
color=None,
nx=None,
agg_func=None,
min_count=None,
):
"""
Computes the lat-lon aggregation at hexagonal bin level.
Latitude and longitude need to be projected to WGS84 before aggregating
in order to display regular hexagons on the map.
Parameters
----------
lat : np.ndarray
Array of latitudes (shape N)
lon : np.ndarray
Array of longitudes (shape N)
lat_range : np.ndarray
Min and max latitudes (shape 2)
lon_range : np.ndarray
Min and max longitudes (shape 2)
color : np.ndarray
Metric to aggregate at hexagon level (shape N)
nx : int
Number of hexagons horizontally
agg_func : function
Numpy compatible aggregator, this function must take a one-dimensional
np.ndarray as input and output a scalar
min_count : int
Minimum number of points in the hexagon for the hexagon to be displayed
Returns
-------
np.ndarray
Lat coordinates of each hexagon (shape M x 6)
np.ndarray
Lon coordinates of each hexagon (shape M x 6)
pd.Series
Unique id for each hexagon, to be used in the geojson data (shape M)
np.ndarray
Aggregated value in each hexagon (shape M)
"""
# Project to WGS 84
x, y = _project_latlon_to_wgs84(lat, lon)
if lat_range is None:
lat_range = np.array([lat.min(), lat.max()])
if lon_range is None:
lon_range = np.array([lon.min(), lon.max()])
x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range)
hxs, hys, centers, agreggated_value = _compute_hexbin(
x, y, x_range, y_range, color, nx, agg_func, min_count
)
# Convert back to lat-lon
hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys)
# Create unique feature id based on hexagon center
centers = centers.astype(str)
hexagons_ids = pd.Series(centers[:, 0]) + "," + pd.Series(centers[:, 1])
return hexagons_lats, hexagons_lons, hexagons_ids, agreggated_value
def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None):
"""
Creates a geojson of hexagonal features based on the outputs of
_compute_wgs84_hexbin
"""
features = []
if ids is None:
ids = np.arange(len(hexagons_lats))
for lat, lon, idx in zip(hexagons_lats, hexagons_lons, ids):
points = np.array([lon, lat]).T.tolist()
points.append(points[0])
features.append(
dict(
type="Feature",
id=idx,
geometry=dict(type="Polygon", coordinates=[points]),
)
)
return dict(type="FeatureCollection", features=features)
def create_hexbin_mapbox(
data_frame=None,
lat=None,
lon=None,
color=None,
nx_hexagon=5,
agg_func=None,
animation_frame=None,
color_discrete_sequence=None,
color_discrete_map={},
labels={},
color_continuous_scale=None,
range_color=None,
color_continuous_midpoint=None,
opacity=None,
zoom=None,
center=None,
mapbox_style=None,
title=None,
template=None,
width=None,
height=None,
min_count=None,
show_original_data=False,
original_data_marker=None,
):
"""
Returns a figure aggregating scattered points into connected hexagons
"""
args = build_dataframe(args=locals(), constructor=None)
if agg_func is None:
agg_func = np.mean
lat_range = args["data_frame"][args["lat"]].agg(["min", "max"]).values
lon_range = args["data_frame"][args["lon"]].agg(["min", "max"]).values
hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin(
lat=args["data_frame"][args["lat"]].values,
lon=args["data_frame"][args["lon"]].values,
lat_range=lat_range,
lon_range=lon_range,
color=None,
nx=nx_hexagon,
agg_func=agg_func,
min_count=min_count,
)
geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids)
if zoom is None:
if height is None and width is None:
mapDim = dict(height=450, width=450)
elif height is None and width is not None:
mapDim = dict(height=450, width=width)
elif height is not None and width is None:
mapDim = dict(height=height, width=height)
else:
mapDim = dict(height=height, width=width)
zoom = _getBoundsZoomLevel(
lon_range[0], lon_range[1], lat_range[0], lat_range[1], mapDim
)
if center is None:
center = dict(lat=lat_range.mean(), lon=lon_range.mean())
if args["animation_frame"] is not None:
groups = args["data_frame"].groupby(args["animation_frame"]).groups
else:
groups = {0: args["data_frame"].index}
agg_data_frame_list = []
for frame, index in groups.items():
df = args["data_frame"].loc[index]
_, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin(
lat=df[args["lat"]].values,
lon=df[args["lon"]].values,
lat_range=lat_range,
lon_range=lon_range,
color=df[args["color"]].values if args["color"] else None,
nx=nx_hexagon,
agg_func=agg_func,
min_count=min_count,
)
agg_data_frame_list.append(
pd.DataFrame(
np.c_[hexagons_ids, aggregated_value], columns=["locations", "color"]
)
)
agg_data_frame = (
pd.concat(agg_data_frame_list, axis=0, keys=groups.keys())
.rename_axis(index=("frame", "index"))
.reset_index("frame")
)
agg_data_frame["color"] = pd.to_numeric(agg_data_frame["color"])
if range_color is None:
range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()]
fig = choropleth_mapbox(
data_frame=agg_data_frame,
geojson=geojson,
locations="locations",
color="color",
hover_data={"color": True, "locations": False, "frame": False},
animation_frame=("frame" if args["animation_frame"] is not None else None),
color_discrete_sequence=color_discrete_sequence,
color_discrete_map=color_discrete_map,
labels=labels,
color_continuous_scale=color_continuous_scale,
range_color=range_color,
color_continuous_midpoint=color_continuous_midpoint,
opacity=opacity,
zoom=zoom,
center=center,
mapbox_style=mapbox_style,
title=title,
template=template,
width=width,
height=height,
)
if show_original_data:
original_fig = scatter_mapbox(
data_frame=(
args["data_frame"].sort_values(by=args["animation_frame"])
if args["animation_frame"] is not None
else args["data_frame"]
),
lat=args["lat"],
lon=args["lon"],
animation_frame=args["animation_frame"],
)
original_fig.data[0].hoverinfo = "skip"
original_fig.data[0].hovertemplate = None
original_fig.data[0].marker = original_data_marker
fig.add_trace(original_fig.data[0])
if args["animation_frame"] is not None:
for i in range(len(original_fig.frames)):
original_fig.frames[i].data[0].hoverinfo = "skip"
original_fig.frames[i].data[0].hovertemplate = None
original_fig.frames[i].data[0].marker = original_data_marker
fig.frames[i].data = [
fig.frames[i].data[0],
original_fig.frames[i].data[0],
]
return fig
create_hexbin_mapbox.__doc__ = make_docstring(
create_hexbin_mapbox,
override_dict=dict(
nx_hexagon=["int", "Number of hexagons (horizontally) to be created"],
agg_func=[
"function",
"Numpy array aggregator, it must take as input a 1D array",
"and output a scalar value.",
],
min_count=[
"int",
"Minimum number of points in a hexagon for it to be displayed.",
"If None and color is not set, display all hexagons.",
"If None and color is set, only display hexagons that contain points.",
],
show_original_data=[
"bool",
"Whether to show the original data on top of the hexbin aggregation.",
],
original_data_marker=["dict", "Scattermapbox marker options."],
),
)
@@ -0,0 +1,297 @@
from __future__ import absolute_import
from plotly import exceptions
from plotly.graph_objs import graph_objs
from plotly.figure_factory import utils
# Default colours for finance charts
_DEFAULT_INCREASING_COLOR = "#3D9970" # http://clrs.cc
_DEFAULT_DECREASING_COLOR = "#FF4136"
def validate_ohlc(open, high, low, close, direction, **kwargs):
"""
ohlc and candlestick specific validations
Specifically, this checks that the high value is the greatest value and
the low value is the lowest value in each unit.
See FigureFactory.create_ohlc() or FigureFactory.create_candlestick()
for params
:raises: (PlotlyError) If the high value is not the greatest value in
each unit.
:raises: (PlotlyError) If the low value is not the lowest value in each
unit.
:raises: (PlotlyError) If direction is not 'increasing' or 'decreasing'
"""
for lst in [open, low, close]:
for index in range(len(high)):
if high[index] < lst[index]:
raise exceptions.PlotlyError(
"Oops! Looks like some of "
"your high values are less "
"the corresponding open, "
"low, or close values. "
"Double check that your data "
"is entered in O-H-L-C order"
)
for lst in [open, high, close]:
for index in range(len(low)):
if low[index] > lst[index]:
raise exceptions.PlotlyError(
"Oops! Looks like some of "
"your low values are greater "
"than the corresponding high"
", open, or close values. "
"Double check that your data "
"is entered in O-H-L-C order"
)
direction_opts = ("increasing", "decreasing", "both")
if direction not in direction_opts:
raise exceptions.PlotlyError(
"direction must be defined as " "'increasing', 'decreasing', or " "'both'"
)
def make_increasing_ohlc(open, high, low, close, dates, **kwargs):
"""
Makes increasing ohlc sticks
_make_increasing_ohlc() and _make_decreasing_ohlc separate the
increasing trace from the decreasing trace so kwargs (such as
color) can be passed separately to increasing or decreasing traces
when direction is set to 'increasing' or 'decreasing' in
FigureFactory.create_candlestick()
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param kwargs: kwargs to be passed to increasing trace via
plotly.graph_objs.Scatter.
:rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc
sticks.
"""
(flat_increase_x, flat_increase_y, text_increase) = _OHLC(
open, high, low, close, dates
).get_increase()
if "name" in kwargs:
showlegend = True
else:
kwargs.setdefault("name", "Increasing")
showlegend = False
kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR, width=1))
kwargs.setdefault("text", text_increase)
ohlc_incr = dict(
type="scatter",
x=flat_increase_x,
y=flat_increase_y,
mode="lines",
showlegend=showlegend,
**kwargs,
)
return ohlc_incr
def make_decreasing_ohlc(open, high, low, close, dates, **kwargs):
"""
Makes decreasing ohlc sticks
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param kwargs: kwargs to be passed to increasing trace via
plotly.graph_objs.Scatter.
:rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc
sticks.
"""
(flat_decrease_x, flat_decrease_y, text_decrease) = _OHLC(
open, high, low, close, dates
).get_decrease()
kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR, width=1))
kwargs.setdefault("text", text_decrease)
kwargs.setdefault("showlegend", False)
kwargs.setdefault("name", "Decreasing")
ohlc_decr = dict(
type="scatter", x=flat_decrease_x, y=flat_decrease_y, mode="lines", **kwargs
)
return ohlc_decr
def create_ohlc(open, high, low, close, dates=None, direction="both", **kwargs):
"""
**deprecated**, use instead the plotly.graph_objects trace
:class:`plotly.graph_objects.Ohlc`
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing
:param (list) dates: list of datetime objects. Default: None
:param (string) direction: direction can be 'increasing', 'decreasing',
or 'both'. When the direction is 'increasing', the returned figure
consists of all units where the close value is greater than the
corresponding open value, and when the direction is 'decreasing',
the returned figure consists of all units where the close value is
less than or equal to the corresponding open value. When the
direction is 'both', both increasing and decreasing units are
returned. Default: 'both'
:param kwargs: kwargs passed through plotly.graph_objs.Scatter.
These kwargs describe other attributes about the ohlc Scatter trace
such as the color or the legend name. For more information on valid
kwargs call help(plotly.graph_objs.Scatter)
:rtype (dict): returns a representation of an ohlc chart figure.
Example 1: Simple OHLC chart from a Pandas DataFrame
>>> from plotly.figure_factory import create_ohlc
>>> from datetime import datetime
>>> import pandas as pd
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
>>> fig = create_ohlc(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], dates=df.index)
>>> fig.show()
"""
if dates is not None:
utils.validate_equal_length(open, high, low, close, dates)
else:
utils.validate_equal_length(open, high, low, close)
validate_ohlc(open, high, low, close, direction, **kwargs)
if direction == "increasing":
ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs)
data = [ohlc_incr]
elif direction == "decreasing":
ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs)
data = [ohlc_decr]
else:
ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs)
ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs)
data = [ohlc_incr, ohlc_decr]
layout = graph_objs.Layout(xaxis=dict(zeroline=False), hovermode="closest")
return graph_objs.Figure(data=data, layout=layout)
class _OHLC(object):
"""
Refer to FigureFactory.create_ohlc_increase() for docstring.
"""
def __init__(self, open, high, low, close, dates, **kwargs):
self.open = open
self.high = high
self.low = low
self.close = close
self.empty = [None] * len(open)
self.dates = dates
self.all_x = []
self.all_y = []
self.increase_x = []
self.increase_y = []
self.decrease_x = []
self.decrease_y = []
self.get_all_xy()
self.separate_increase_decrease()
def get_all_xy(self):
"""
Zip data to create OHLC shape
OHLC shape: low to high vertical bar with
horizontal branches for open and close values.
If dates were added, the smallest date difference is calculated and
multiplied by .2 to get the length of the open and close branches.
If no date data was provided, the x-axis is a list of integers and the
length of the open and close branches is .2.
"""
self.all_y = list(
zip(
self.open,
self.open,
self.high,
self.low,
self.close,
self.close,
self.empty,
)
)
if self.dates is not None:
date_dif = []
for i in range(len(self.dates) - 1):
date_dif.append(self.dates[i + 1] - self.dates[i])
date_dif_min = (min(date_dif)) / 5
self.all_x = [
[x - date_dif_min, x, x, x, x, x + date_dif_min, None]
for x in self.dates
]
else:
self.all_x = [
[x - 0.2, x, x, x, x, x + 0.2, None] for x in range(len(self.open))
]
def separate_increase_decrease(self):
"""
Separate data into two groups: increase and decrease
(1) Increase, where close > open and
(2) Decrease, where close <= open
"""
for index in range(len(self.open)):
if self.close[index] is None:
pass
elif self.close[index] > self.open[index]:
self.increase_x.append(self.all_x[index])
self.increase_y.append(self.all_y[index])
else:
self.decrease_x.append(self.all_x[index])
self.decrease_y.append(self.all_y[index])
def get_increase(self):
"""
Flatten increase data and get increase text
:rtype (list, list, list): flat_increase_x: x-values for the increasing
trace, flat_increase_y: y=values for the increasing trace and
text_increase: hovertext for the increasing trace
"""
flat_increase_x = utils.flatten(self.increase_x)
flat_increase_y = utils.flatten(self.increase_y)
text_increase = ("Open", "Open", "High", "Low", "Close", "Close", "") * (
len(self.increase_x)
)
return flat_increase_x, flat_increase_y, text_increase
def get_decrease(self):
"""
Flatten decrease data and get decrease text
:rtype (list, list, list): flat_decrease_x: x-values for the decreasing
trace, flat_decrease_y: y=values for the decreasing trace and
text_decrease: hovertext for the decreasing trace
"""
flat_decrease_x = utils.flatten(self.decrease_x)
flat_decrease_y = utils.flatten(self.decrease_y)
text_decrease = ("Open", "Open", "High", "Low", "Close", "Close", "") * (
len(self.decrease_x)
)
return flat_decrease_x, flat_decrease_y, text_decrease
@@ -0,0 +1,267 @@
from __future__ import absolute_import
import math
from plotly import exceptions
from plotly.graph_objs import graph_objs
from plotly.figure_factory import utils
def create_quiver(
x, y, u, v, scale=0.1, arrow_scale=0.3, angle=math.pi / 9, scaleratio=None, **kwargs
):
"""
Returns data for a quiver plot.
:param (list|ndarray) x: x coordinates of the arrow locations
:param (list|ndarray) y: y coordinates of the arrow locations
:param (list|ndarray) u: x components of the arrow vectors
:param (list|ndarray) v: y components of the arrow vectors
:param (float in [0,1]) scale: scales size of the arrows(ideally to
avoid overlap). Default = .1
:param (float in [0,1]) arrow_scale: value multiplied to length of barb
to get length of arrowhead. Default = .3
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
:param (positive float) scaleratio: the ratio between the scale of the y-axis
and the scale of the x-axis (scale_y / scale_x). Default = None, the
scale ratio is not fixed.
:param kwargs: kwargs passed through plotly.graph_objs.Scatter
for more information on valid kwargs call
help(plotly.graph_objs.Scatter)
:rtype (dict): returns a representation of quiver figure.
Example 1: Trivial Quiver
>>> from plotly.figure_factory import create_quiver
>>> import math
>>> # 1 Arrow from (0,0) to (1,1)
>>> fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1)
>>> fig.show()
Example 2: Quiver plot using meshgrid
>>> from plotly.figure_factory import create_quiver
>>> import numpy as np
>>> import math
>>> # Add data
>>> x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
>>> u = np.cos(x)*y
>>> v = np.sin(x)*y
>>> #Create quiver
>>> fig = create_quiver(x, y, u, v)
>>> fig.show()
Example 3: Styling the quiver plot
>>> from plotly.figure_factory import create_quiver
>>> import numpy as np
>>> import math
>>> # Add data
>>> x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5),
... np.arange(-math.pi, math.pi, .5))
>>> u = np.cos(x)*y
>>> v = np.sin(x)*y
>>> # Create quiver
>>> fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6,
... name='Wind Velocity', line=dict(width=1))
>>> # Add title to layout
>>> fig.update_layout(title='Quiver Plot') # doctest: +SKIP
>>> fig.show()
Example 4: Forcing a fix scale ratio to maintain the arrow length
>>> from plotly.figure_factory import create_quiver
>>> import numpy as np
>>> # Add data
>>> x,y = np.meshgrid(np.arange(0.5, 3.5, .5), np.arange(0.5, 4.5, .5))
>>> u = x
>>> v = y
>>> angle = np.arctan(v / u)
>>> norm = 0.25
>>> u = norm * np.cos(angle)
>>> v = norm * np.sin(angle)
>>> # Create quiver with a fix scale ratio
>>> fig = create_quiver(x, y, u, v, scale = 1, scaleratio = 0.5)
>>> fig.show()
"""
utils.validate_equal_length(x, y, u, v)
utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale)
if scaleratio is None:
quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle)
else:
quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle, scaleratio)
barb_x, barb_y = quiver_obj.get_barbs()
arrow_x, arrow_y = quiver_obj.get_quiver_arrows()
quiver_plot = graph_objs.Scatter(
x=barb_x + arrow_x, y=barb_y + arrow_y, mode="lines", **kwargs
)
data = [quiver_plot]
if scaleratio is None:
layout = graph_objs.Layout(hovermode="closest")
else:
layout = graph_objs.Layout(
hovermode="closest", yaxis=dict(scaleratio=scaleratio, scaleanchor="x")
)
return graph_objs.Figure(data=data, layout=layout)
class _Quiver(object):
"""
Refer to FigureFactory.create_quiver() for docstring
"""
def __init__(self, x, y, u, v, scale, arrow_scale, angle, scaleratio=1, **kwargs):
try:
x = utils.flatten(x)
except exceptions.PlotlyError:
pass
try:
y = utils.flatten(y)
except exceptions.PlotlyError:
pass
try:
u = utils.flatten(u)
except exceptions.PlotlyError:
pass
try:
v = utils.flatten(v)
except exceptions.PlotlyError:
pass
self.x = x
self.y = y
self.u = u
self.v = v
self.scale = scale
self.scaleratio = scaleratio
self.arrow_scale = arrow_scale
self.angle = angle
self.end_x = []
self.end_y = []
self.scale_uv()
barb_x, barb_y = self.get_barbs()
arrow_x, arrow_y = self.get_quiver_arrows()
def scale_uv(self):
"""
Scales u and v to avoid overlap of the arrows.
u and v are added to x and y to get the
endpoints of the arrows so a smaller scale value will
result in less overlap of arrows.
"""
self.u = [i * self.scale * self.scaleratio for i in self.u]
self.v = [i * self.scale for i in self.v]
def get_barbs(self):
"""
Creates x and y startpoint and endpoint pairs
After finding the endpoint of each barb this zips startpoint and
endpoint pairs to create 2 lists: x_values for barbs and y values
for barbs
:rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint
x_value pairs separated by a None to create the barb of the arrow,
and list of startpoint and endpoint y_value pairs separated by a
None to create the barb of the arrow.
"""
self.end_x = [i + j for i, j in zip(self.x, self.u)]
self.end_y = [i + j for i, j in zip(self.y, self.v)]
empty = [None] * len(self.x)
barb_x = utils.flatten(zip(self.x, self.end_x, empty))
barb_y = utils.flatten(zip(self.y, self.end_y, empty))
return barb_x, barb_y
def get_quiver_arrows(self):
"""
Creates lists of x and y values to plot the arrows
Gets length of each barb then calculates the length of each side of
the arrow. Gets angle of barb and applies angle to each side of the
arrowhead. Next uses arrow_scale to scale the length of arrowhead and
creates x and y values for arrowhead point1 and point2. Finally x and y
values for point1, endpoint and point2s for each arrowhead are
separated by a None and zipped to create lists of x and y values for
the arrows.
:rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2
x_values separated by a None to create the arrowhead and list of
point1, endpoint, point2 y_values separated by a None to create
the barb of the arrow.
"""
dif_x = [i - j for i, j in zip(self.end_x, self.x)]
dif_y = [i - j for i, j in zip(self.end_y, self.y)]
# Get barb lengths(default arrow length = 30% barb length)
barb_len = [None] * len(self.x)
for index in range(len(barb_len)):
barb_len[index] = math.hypot(dif_x[index] / self.scaleratio, dif_y[index])
# Make arrow lengths
arrow_len = [None] * len(self.x)
arrow_len = [i * self.arrow_scale for i in barb_len]
# Get barb angles
barb_ang = [None] * len(self.x)
for index in range(len(barb_ang)):
barb_ang[index] = math.atan2(dif_y[index], dif_x[index] / self.scaleratio)
# Set angles to create arrow
ang1 = [i + self.angle for i in barb_ang]
ang2 = [i - self.angle for i in barb_ang]
cos_ang1 = [None] * len(ang1)
for index in range(len(ang1)):
cos_ang1[index] = math.cos(ang1[index])
seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)]
sin_ang1 = [None] * len(ang1)
for index in range(len(ang1)):
sin_ang1[index] = math.sin(ang1[index])
seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)]
cos_ang2 = [None] * len(ang2)
for index in range(len(ang2)):
cos_ang2[index] = math.cos(ang2[index])
seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)]
sin_ang2 = [None] * len(ang2)
for index in range(len(ang2)):
sin_ang2[index] = math.sin(ang2[index])
seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)]
# Set coordinates to create arrow
for index in range(len(self.end_x)):
point1_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg1_x)]
point1_y = [i - j for i, j in zip(self.end_y, seg1_y)]
point2_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg2_x)]
point2_y = [i - j for i, j in zip(self.end_y, seg2_y)]
# Combine lists to create arrow
empty = [None] * len(self.end_x)
arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty))
arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty))
return arrow_x, arrow_y
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,408 @@
from __future__ import absolute_import
import math
from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
np = optional_imports.get_module("numpy")
def validate_streamline(x, y):
"""
Streamline-specific validations
Specifically, this checks that x and y are both evenly spaced,
and that the package numpy is available.
See FigureFactory.create_streamline() for params
:raises: (ImportError) If numpy is not available.
:raises: (PlotlyError) If x is not evenly spaced.
:raises: (PlotlyError) If y is not evenly spaced.
"""
if np is False:
raise ImportError("FigureFactory.create_streamline requires numpy")
for index in range(len(x) - 1):
if ((x[index + 1] - x[index]) - (x[1] - x[0])) > 0.0001:
raise exceptions.PlotlyError(
"x must be a 1 dimensional, " "evenly spaced array"
)
for index in range(len(y) - 1):
if ((y[index + 1] - y[index]) - (y[1] - y[0])) > 0.0001:
raise exceptions.PlotlyError(
"y must be a 1 dimensional, " "evenly spaced array"
)
def create_streamline(
x, y, u, v, density=1, angle=math.pi / 9, arrow_scale=0.09, **kwargs
):
"""
Returns data for a streamline plot.
:param (list|ndarray) x: 1 dimensional, evenly spaced list or array
:param (list|ndarray) y: 1 dimensional, evenly spaced list or array
:param (ndarray) u: 2 dimensional array
:param (ndarray) v: 2 dimensional array
:param (float|int) density: controls the density of streamlines in
plot. This is multiplied by 30 to scale similiarly to other
available streamline functions such as matplotlib.
Default = 1
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
:param (float in [0,1]) arrow_scale: value to scale length of arrowhead
Default = .09
:param kwargs: kwargs passed through plotly.graph_objs.Scatter
for more information on valid kwargs call
help(plotly.graph_objs.Scatter)
:rtype (dict): returns a representation of streamline figure.
Example 1: Plot simple streamline and increase arrow size
>>> from plotly.figure_factory import create_streamline
>>> import plotly.graph_objects as go
>>> import numpy as np
>>> import math
>>> # Add data
>>> x = np.linspace(-3, 3, 100)
>>> y = np.linspace(-3, 3, 100)
>>> Y, X = np.meshgrid(x, y)
>>> u = -1 - X**2 + Y
>>> v = 1 + X - Y**2
>>> u = u.T # Transpose
>>> v = v.T # Transpose
>>> # Create streamline
>>> fig = create_streamline(x, y, u, v, arrow_scale=.1)
>>> fig.show()
Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython
>>> from plotly.figure_factory import create_streamline
>>> import numpy as np
>>> import math
>>> # Add data
>>> N = 50
>>> x_start, x_end = -2.0, 2.0
>>> y_start, y_end = -1.0, 1.0
>>> x = np.linspace(x_start, x_end, N)
>>> y = np.linspace(y_start, y_end, N)
>>> X, Y = np.meshgrid(x, y)
>>> ss = 5.0
>>> x_s, y_s = -1.0, 0.0
>>> # Compute the velocity field on the mesh grid
>>> u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2)
>>> v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2)
>>> # Create streamline
>>> fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline')
>>> # Add source point
>>> point = go.Scatter(x=[x_s], y=[y_s], mode='markers',
... marker_size=14, name='source point')
>>> fig.add_trace(point) # doctest: +SKIP
>>> fig.show()
"""
utils.validate_equal_length(x, y)
utils.validate_equal_length(u, v)
validate_streamline(x, y)
utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale)
streamline_x, streamline_y = _Streamline(
x, y, u, v, density, angle, arrow_scale
).sum_streamlines()
arrow_x, arrow_y = _Streamline(
x, y, u, v, density, angle, arrow_scale
).get_streamline_arrows()
streamline = graph_objs.Scatter(
x=streamline_x + arrow_x, y=streamline_y + arrow_y, mode="lines", **kwargs
)
data = [streamline]
layout = graph_objs.Layout(hovermode="closest")
return graph_objs.Figure(data=data, layout=layout)
class _Streamline(object):
"""
Refer to FigureFactory.create_streamline() for docstring
"""
def __init__(self, x, y, u, v, density, angle, arrow_scale, **kwargs):
self.x = np.array(x)
self.y = np.array(y)
self.u = np.array(u)
self.v = np.array(v)
self.angle = angle
self.arrow_scale = arrow_scale
self.density = int(30 * density) # Scale similarly to other functions
self.delta_x = self.x[1] - self.x[0]
self.delta_y = self.y[1] - self.y[0]
self.val_x = self.x
self.val_y = self.y
# Set up spacing
self.blank = np.zeros((self.density, self.density))
self.spacing_x = len(self.x) / float(self.density - 1)
self.spacing_y = len(self.y) / float(self.density - 1)
self.trajectories = []
# Rescale speed onto axes-coordinates
self.u = self.u / (self.x[-1] - self.x[0])
self.v = self.v / (self.y[-1] - self.y[0])
self.speed = np.sqrt(self.u**2 + self.v**2)
# Rescale u and v for integrations.
self.u *= len(self.x)
self.v *= len(self.y)
self.st_x = []
self.st_y = []
self.get_streamlines()
streamline_x, streamline_y = self.sum_streamlines()
arrows_x, arrows_y = self.get_streamline_arrows()
def blank_pos(self, xi, yi):
"""
Set up positions for trajectories to be used with rk4 function.
"""
return (int((xi / self.spacing_x) + 0.5), int((yi / self.spacing_y) + 0.5))
def value_at(self, a, xi, yi):
"""
Set up for RK4 function, based on Bokeh's streamline code
"""
if isinstance(xi, np.ndarray):
self.x = xi.astype(np.int)
self.y = yi.astype(np.int)
else:
self.val_x = np.int(xi)
self.val_y = np.int(yi)
a00 = a[self.val_y, self.val_x]
a01 = a[self.val_y, self.val_x + 1]
a10 = a[self.val_y + 1, self.val_x]
a11 = a[self.val_y + 1, self.val_x + 1]
xt = xi - self.val_x
yt = yi - self.val_y
a0 = a00 * (1 - xt) + a01 * xt
a1 = a10 * (1 - xt) + a11 * xt
return a0 * (1 - yt) + a1 * yt
def rk4_integrate(self, x0, y0):
"""
RK4 forward and back trajectories from the initial conditions.
Adapted from Bokeh's streamline -uses Runge-Kutta method to fill
x and y trajectories then checks length of traj (s in units of axes)
"""
def f(xi, yi):
dt_ds = 1.0 / self.value_at(self.speed, xi, yi)
ui = self.value_at(self.u, xi, yi)
vi = self.value_at(self.v, xi, yi)
return ui * dt_ds, vi * dt_ds
def g(xi, yi):
dt_ds = 1.0 / self.value_at(self.speed, xi, yi)
ui = self.value_at(self.u, xi, yi)
vi = self.value_at(self.v, xi, yi)
return -ui * dt_ds, -vi * dt_ds
check = lambda xi, yi: (0 <= xi < len(self.x) - 1 and 0 <= yi < len(self.y) - 1)
xb_changes = []
yb_changes = []
def rk4(x0, y0, f):
ds = 0.01
stotal = 0
xi = x0
yi = y0
xb, yb = self.blank_pos(xi, yi)
xf_traj = []
yf_traj = []
while check(xi, yi):
xf_traj.append(xi)
yf_traj.append(yi)
try:
k1x, k1y = f(xi, yi)
k2x, k2y = f(xi + 0.5 * ds * k1x, yi + 0.5 * ds * k1y)
k3x, k3y = f(xi + 0.5 * ds * k2x, yi + 0.5 * ds * k2y)
k4x, k4y = f(xi + ds * k3x, yi + ds * k3y)
except IndexError:
break
xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.0
yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.0
if not check(xi, yi):
break
stotal += ds
new_xb, new_yb = self.blank_pos(xi, yi)
if new_xb != xb or new_yb != yb:
if self.blank[new_yb, new_xb] == 0:
self.blank[new_yb, new_xb] = 1
xb_changes.append(new_xb)
yb_changes.append(new_yb)
xb = new_xb
yb = new_yb
else:
break
if stotal > 2:
break
return stotal, xf_traj, yf_traj
sf, xf_traj, yf_traj = rk4(x0, y0, f)
sb, xb_traj, yb_traj = rk4(x0, y0, g)
stotal = sf + sb
x_traj = xb_traj[::-1] + xf_traj[1:]
y_traj = yb_traj[::-1] + yf_traj[1:]
if len(x_traj) < 1:
return None
if stotal > 0.2:
initxb, inityb = self.blank_pos(x0, y0)
self.blank[inityb, initxb] = 1
return x_traj, y_traj
else:
for xb, yb in zip(xb_changes, yb_changes):
self.blank[yb, xb] = 0
return None
def traj(self, xb, yb):
"""
Integrate trajectories
:param (int) xb: results of passing xi through self.blank_pos
:param (int) xy: results of passing yi through self.blank_pos
Calculate each trajectory based on rk4 integrate method.
"""
if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density:
return
if self.blank[yb, xb] == 0:
t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y)
if t is not None:
self.trajectories.append(t)
def get_streamlines(self):
"""
Get streamlines by building trajectory set.
"""
for indent in range(self.density // 2):
for xi in range(self.density - 2 * indent):
self.traj(xi + indent, indent)
self.traj(xi + indent, self.density - 1 - indent)
self.traj(indent, xi + indent)
self.traj(self.density - 1 - indent, xi + indent)
self.st_x = [
np.array(t[0]) * self.delta_x + self.x[0] for t in self.trajectories
]
self.st_y = [
np.array(t[1]) * self.delta_y + self.y[0] for t in self.trajectories
]
for index in range(len(self.st_x)):
self.st_x[index] = self.st_x[index].tolist()
self.st_x[index].append(np.nan)
for index in range(len(self.st_y)):
self.st_y[index] = self.st_y[index].tolist()
self.st_y[index].append(np.nan)
def get_streamline_arrows(self):
"""
Makes an arrow for each streamline.
Gets angle of streamline at 1/3 mark and creates arrow coordinates
based off of user defined angle and arrow_scale.
:param (array) st_x: x-values for all streamlines
:param (array) st_y: y-values for all streamlines
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
:param (float in [0,1]) arrow_scale: value to scale length of arrowhead
Default = .09
:rtype (list, list) arrows_x: x-values to create arrowhead and
arrows_y: y-values to create arrowhead
"""
arrow_end_x = np.empty((len(self.st_x)))
arrow_end_y = np.empty((len(self.st_y)))
arrow_start_x = np.empty((len(self.st_x)))
arrow_start_y = np.empty((len(self.st_y)))
for index in range(len(self.st_x)):
arrow_end_x[index] = self.st_x[index][int(len(self.st_x[index]) / 3)]
arrow_start_x[index] = self.st_x[index][
(int(len(self.st_x[index]) / 3)) - 1
]
arrow_end_y[index] = self.st_y[index][int(len(self.st_y[index]) / 3)]
arrow_start_y[index] = self.st_y[index][
(int(len(self.st_y[index]) / 3)) - 1
]
dif_x = arrow_end_x - arrow_start_x
dif_y = arrow_end_y - arrow_start_y
orig_err = np.geterr()
np.seterr(divide="ignore", invalid="ignore")
streamline_ang = np.arctan(dif_y / dif_x)
np.seterr(**orig_err)
ang1 = streamline_ang + (self.angle)
ang2 = streamline_ang - (self.angle)
seg1_x = np.cos(ang1) * self.arrow_scale
seg1_y = np.sin(ang1) * self.arrow_scale
seg2_x = np.cos(ang2) * self.arrow_scale
seg2_y = np.sin(ang2) * self.arrow_scale
point1_x = np.empty((len(dif_x)))
point1_y = np.empty((len(dif_y)))
point2_x = np.empty((len(dif_x)))
point2_y = np.empty((len(dif_y)))
for index in range(len(dif_x)):
if dif_x[index] >= 0:
point1_x[index] = arrow_end_x[index] - seg1_x[index]
point1_y[index] = arrow_end_y[index] - seg1_y[index]
point2_x[index] = arrow_end_x[index] - seg2_x[index]
point2_y[index] = arrow_end_y[index] - seg2_y[index]
else:
point1_x[index] = arrow_end_x[index] + seg1_x[index]
point1_y[index] = arrow_end_y[index] + seg1_y[index]
point2_x[index] = arrow_end_x[index] + seg2_x[index]
point2_y[index] = arrow_end_y[index] + seg2_y[index]
space = np.empty((len(point1_x)))
space[:] = np.nan
# Combine arrays into matrix
arrows_x = np.matrix([point1_x, arrow_end_x, point2_x, space])
arrows_x = np.array(arrows_x)
arrows_x = arrows_x.flatten("F")
arrows_x = arrows_x.tolist()
# Combine arrays into matrix
arrows_y = np.matrix([point1_y, arrow_end_y, point2_y, space])
arrows_y = np.array(arrows_y)
arrows_y = arrows_y.flatten("F")
arrows_y = arrows_y.tolist()
return arrows_x, arrows_y
def sum_streamlines(self):
"""
Makes all streamlines readable as a single trace.
:rtype (list, list): streamline_x: all x values for each streamline
combined into single list and streamline_y: all y values for each
streamline combined into single list
"""
streamline_x = sum(self.st_x, [])
streamline_y = sum(self.st_y, [])
return streamline_x, streamline_y
@@ -0,0 +1,283 @@
from __future__ import absolute_import
from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs
pd = optional_imports.get_module("pandas")
def validate_table(table_text, font_colors):
"""
Table-specific validations
Check that font_colors is supplied correctly (1, 3, or len(text)
colors).
:raises: (PlotlyError) If font_colors is supplied incorretly.
See FigureFactory.create_table() for params
"""
font_colors_len_options = [1, 3, len(table_text)]
if len(font_colors) not in font_colors_len_options:
raise exceptions.PlotlyError(
"Oops, font_colors should be a list " "of length 1, 3 or len(text)"
)
def create_table(
table_text,
colorscale=None,
font_colors=None,
index=False,
index_title="",
annotation_offset=0.45,
height_constant=30,
hoverinfo="none",
**kwargs,
):
"""
Function that creates data tables.
See also the plotly.graph_objects trace
:class:`plotly.graph_objects.Table`
:param (pandas.Dataframe | list[list]) text: data for table.
:param (str|list[list]) colorscale: Colorscale for table where the
color at value 0 is the header color, .5 is the first table color
and 1 is the second table color. (Set .5 and 1 to avoid the striped
table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'],
[1, '#ffffff']]
:param (list) font_colors: Color for fonts in table. Can be a single
color, three colors, or a color for each row in the table.
Default=['#000000'] (black text for the entire table)
:param (int) height_constant: Constant multiplied by # of rows to
create table height. Default=30.
:param (bool) index: Create (header-colored) index column index from
Pandas dataframe or list[0] for each list in text. Default=False.
:param (string) index_title: Title for index column. Default=''.
:param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
These kwargs describe other attributes about the annotated Heatmap
trace such as the colorscale. For more information on valid kwargs
call help(plotly.graph_objs.Heatmap)
Example 1: Simple Plotly Table
>>> from plotly.figure_factory import create_table
>>> text = [['Country', 'Year', 'Population'],
... ['US', 2000, 282200000],
... ['Canada', 2000, 27790000],
... ['US', 2010, 309000000],
... ['Canada', 2010, 34000000]]
>>> table = create_table(text)
>>> table.show()
Example 2: Table with Custom Coloring
>>> from plotly.figure_factory import create_table
>>> text = [['Country', 'Year', 'Population'],
... ['US', 2000, 282200000],
... ['Canada', 2000, 27790000],
... ['US', 2010, 309000000],
... ['Canada', 2010, 34000000]]
>>> table = create_table(text,
... colorscale=[[0, '#000000'],
... [.5, '#80beff'],
... [1, '#cce5ff']],
... font_colors=['#ffffff', '#000000',
... '#000000'])
>>> table.show()
Example 3: Simple Plotly Table with Pandas
>>> from plotly.figure_factory import create_table
>>> import pandas as pd
>>> df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t')
>>> df_p = df[0:25]
>>> table_simple = create_table(df_p)
>>> table_simple.show()
"""
# Avoiding mutables in the call signature
colorscale = (
colorscale
if colorscale is not None
else [[0, "#00083e"], [0.5, "#ededee"], [1, "#ffffff"]]
)
font_colors = (
font_colors if font_colors is not None else ["#ffffff", "#000000", "#000000"]
)
validate_table(table_text, font_colors)
table_matrix = _Table(
table_text,
colorscale,
font_colors,
index,
index_title,
annotation_offset,
**kwargs,
).get_table_matrix()
annotations = _Table(
table_text,
colorscale,
font_colors,
index,
index_title,
annotation_offset,
**kwargs,
).make_table_annotations()
trace = dict(
type="heatmap",
z=table_matrix,
opacity=0.75,
colorscale=colorscale,
showscale=False,
hoverinfo=hoverinfo,
**kwargs,
)
data = [trace]
layout = dict(
annotations=annotations,
height=len(table_matrix) * height_constant + 50,
margin=dict(t=0, b=0, r=0, l=0),
yaxis=dict(
autorange="reversed",
zeroline=False,
gridwidth=2,
ticks="",
dtick=1,
tick0=0.5,
showticklabels=False,
),
xaxis=dict(
zeroline=False,
gridwidth=2,
ticks="",
dtick=1,
tick0=-0.5,
showticklabels=False,
),
)
return graph_objs.Figure(data=data, layout=layout)
class _Table(object):
"""
Refer to TraceFactory.create_table() for docstring
"""
def __init__(
self,
table_text,
colorscale,
font_colors,
index,
index_title,
annotation_offset,
**kwargs,
):
if pd and isinstance(table_text, pd.DataFrame):
headers = table_text.columns.tolist()
table_text_index = table_text.index.tolist()
table_text = table_text.values.tolist()
table_text.insert(0, headers)
if index:
table_text_index.insert(0, index_title)
for i in range(len(table_text)):
table_text[i].insert(0, table_text_index[i])
self.table_text = table_text
self.colorscale = colorscale
self.font_colors = font_colors
self.index = index
self.annotation_offset = annotation_offset
self.x = range(len(table_text[0]))
self.y = range(len(table_text))
def get_table_matrix(self):
"""
Create z matrix to make heatmap with striped table coloring
:rtype (list[list]) table_matrix: z matrix to make heatmap with striped
table coloring.
"""
header = [0] * len(self.table_text[0])
odd_row = [0.5] * len(self.table_text[0])
even_row = [1] * len(self.table_text[0])
table_matrix = [None] * len(self.table_text)
table_matrix[0] = header
for i in range(1, len(self.table_text), 2):
table_matrix[i] = odd_row
for i in range(2, len(self.table_text), 2):
table_matrix[i] = even_row
if self.index:
for array in table_matrix:
array[0] = 0
return table_matrix
def get_table_font_color(self):
"""
Fill font-color array.
Table text color can vary by row so this extends a single color or
creates an array to set a header color and two alternating colors to
create the striped table pattern.
:rtype (list[list]) all_font_colors: list of font colors for each row
in table.
"""
if len(self.font_colors) == 1:
all_font_colors = self.font_colors * len(self.table_text)
elif len(self.font_colors) == 3:
all_font_colors = list(range(len(self.table_text)))
all_font_colors[0] = self.font_colors[0]
for i in range(1, len(self.table_text), 2):
all_font_colors[i] = self.font_colors[1]
for i in range(2, len(self.table_text), 2):
all_font_colors[i] = self.font_colors[2]
elif len(self.font_colors) == len(self.table_text):
all_font_colors = self.font_colors
else:
all_font_colors = ["#000000"] * len(self.table_text)
return all_font_colors
def make_table_annotations(self):
"""
Generate annotations to fill in table text
:rtype (list) annotations: list of annotations for each cell of the
table.
"""
table_matrix = _Table.get_table_matrix(self)
all_font_colors = _Table.get_table_font_color(self)
annotations = []
for n, row in enumerate(self.table_text):
for m, val in enumerate(row):
# Bold text in header and index
format_text = (
"<b>" + str(val) + "</b>"
if n == 0 or self.index and m < 1
else str(val)
)
# Match font color of index to font color of header
font_color = (
self.font_colors[0] if self.index and m == 0 else all_font_colors[n]
)
annotations.append(
graph_objs.layout.Annotation(
text=format_text,
x=self.x[m] - self.annotation_offset,
y=self.y[n],
xref="x1",
yref="y1",
align="left",
xanchor="left",
font=dict(color=font_color),
showarrow=False,
)
)
return annotations
@@ -0,0 +1,699 @@
from __future__ import absolute_import
import plotly.colors as clrs
from plotly.graph_objs import graph_objs as go
from plotly import exceptions, optional_imports
from plotly import optional_imports
from plotly.graph_objs import graph_objs as go
np = optional_imports.get_module("numpy")
scipy_interp = optional_imports.get_module("scipy.interpolate")
from skimage import measure
# -------------------------- Layout ------------------------------
def _ternary_layout(
title="Ternary contour plot", width=550, height=525, pole_labels=["a", "b", "c"]
):
"""
Layout of ternary contour plot, to be passed to ``go.FigureWidget``
object.
Parameters
==========
title : str or None
Title of ternary plot
width : int
Figure width.
height : int
Figure height.
pole_labels : str, default ['a', 'b', 'c']
Names of the three poles of the triangle.
"""
return dict(
title=title,
width=width,
height=height,
ternary=dict(
sum=1,
aaxis=dict(
title=dict(text=pole_labels[0]), min=0.01, linewidth=2, ticks="outside"
),
baxis=dict(
title=dict(text=pole_labels[1]), min=0.01, linewidth=2, ticks="outside"
),
caxis=dict(
title=dict(text=pole_labels[2]), min=0.01, linewidth=2, ticks="outside"
),
),
showlegend=False,
)
# ------------- Transformations of coordinates -------------------
def _replace_zero_coords(ternary_data, delta=0.0005):
"""
Replaces zero ternary coordinates with delta and normalize the new
triplets (a, b, c).
Parameters
----------
ternary_data : ndarray of shape (N, 3)
delta : float
Small float to regularize logarithm.
Notes
-----
Implements a method
by J. A. Martin-Fernandez, C. Barcelo-Vidal, V. Pawlowsky-Glahn,
Dealing with zeros and missing values in compositional data sets
using nonparametric imputation, Mathematical Geology 35 (2003),
pp 253-278.
"""
zero_mask = ternary_data == 0
is_any_coord_zero = np.any(zero_mask, axis=0)
unity_complement = 1 - delta * is_any_coord_zero
if np.any(unity_complement) < 0:
raise ValueError(
"The provided value of delta led to negative"
"ternary coords.Set a smaller delta"
)
ternary_data = np.where(zero_mask, delta, unity_complement * ternary_data)
return ternary_data
def _ilr_transform(barycentric):
"""
Perform Isometric Log-Ratio on barycentric (compositional) data.
Parameters
----------
barycentric: ndarray of shape (3, N)
Barycentric coordinates.
References
----------
"An algebraic method to compute isometric logratio transformation and
back transformation of compositional data", Jarauta-Bragulat, E.,
Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the
Intl Assoc for Math Geology, 2003, pp 31-30.
"""
barycentric = np.asarray(barycentric)
x_0 = np.log(barycentric[0] / barycentric[1]) / np.sqrt(2)
x_1 = (
1.0 / np.sqrt(6) * np.log(barycentric[0] * barycentric[1] / barycentric[2] ** 2)
)
ilr_tdata = np.stack((x_0, x_1))
return ilr_tdata
def _ilr_inverse(x):
"""
Perform inverse Isometric Log-Ratio (ILR) transform to retrieve
barycentric (compositional) data.
Parameters
----------
x : array of shape (2, N)
Coordinates in ILR space.
References
----------
"An algebraic method to compute isometric logratio transformation and
back transformation of compositional data", Jarauta-Bragulat, E.,
Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the
Intl Assoc for Math Geology, 2003, pp 31-30.
"""
x = np.array(x)
matrix = np.array([[0.5, 1, 1.0], [-0.5, 1, 1.0], [0.0, 0.0, 1.0]])
s = np.sqrt(2) / 2
t = np.sqrt(3 / 2)
Sk = np.einsum("ik, kj -> ij", np.array([[s, t], [-s, t]]), x)
Z = -np.log(1 + np.exp(Sk).sum(axis=0))
log_barycentric = np.einsum(
"ik, kj -> ij", matrix, np.stack((2 * s * x[0], t * x[1], Z))
)
iilr_tdata = np.exp(log_barycentric)
return iilr_tdata
def _transform_barycentric_cartesian():
"""
Returns the transformation matrix from barycentric to Cartesian
coordinates and conversely.
"""
# reference triangle
tri_verts = np.array([[0.5, np.sqrt(3) / 2], [0, 0], [1, 0]])
M = np.array([tri_verts[:, 0], tri_verts[:, 1], np.ones(3)])
return M, np.linalg.inv(M)
def _prepare_barycentric_coord(b_coords):
"""
Check ternary coordinates and return the right barycentric coordinates.
"""
if not isinstance(b_coords, (list, np.ndarray)):
raise ValueError(
"Data should be either an array of shape (n,m),"
"or a list of n m-lists, m=2 or 3"
)
b_coords = np.asarray(b_coords)
if b_coords.shape[0] not in (2, 3):
raise ValueError(
"A point should have 2 (a, b) or 3 (a, b, c)" "barycentric coordinates"
)
if (
(len(b_coords) == 3)
and not np.allclose(b_coords.sum(axis=0), 1, rtol=0.01)
and not np.allclose(b_coords.sum(axis=0), 100, rtol=0.01)
):
msg = "The sum of coordinates should be 1 or 100 for all data points"
raise ValueError(msg)
if len(b_coords) == 2:
A, B = b_coords
C = 1 - (A + B)
else:
A, B, C = b_coords / b_coords.sum(axis=0)
if np.any(np.stack((A, B, C)) < 0):
raise ValueError("Barycentric coordinates should be positive.")
return np.stack((A, B, C))
def _compute_grid(coordinates, values, interp_mode="ilr"):
"""
Transform data points with Cartesian or ILR mapping, then Compute
interpolation on a regular grid.
Parameters
==========
coordinates : array-like
Barycentric coordinates of data points.
values : 1-d array-like
Data points, field to be represented as contours.
interp_mode : 'ilr' (default) or 'cartesian'
Defines how data are interpolated to compute contours.
"""
if interp_mode == "cartesian":
M, invM = _transform_barycentric_cartesian()
coord_points = np.einsum("ik, kj -> ij", M, coordinates)
elif interp_mode == "ilr":
coordinates = _replace_zero_coords(coordinates)
coord_points = _ilr_transform(coordinates)
else:
raise ValueError("interp_mode should be cartesian or ilr")
xx, yy = coord_points[:2]
x_min, x_max = xx.min(), xx.max()
y_min, y_max = yy.min(), yy.max()
n_interp = max(200, int(np.sqrt(len(values))))
gr_x = np.linspace(x_min, x_max, n_interp)
gr_y = np.linspace(y_min, y_max, n_interp)
grid_x, grid_y = np.meshgrid(gr_x, gr_y)
# We use cubic interpolation, except outside of the convex hull
# of data points where we use nearest neighbor values.
grid_z = scipy_interp.griddata(
coord_points[:2].T, values, (grid_x, grid_y), method="cubic"
)
grid_z_other = scipy_interp.griddata(
coord_points[:2].T, values, (grid_x, grid_y), method="nearest"
)
# mask_nan = np.isnan(grid_z)
# grid_z[mask_nan] = grid_z_other[mask_nan]
return grid_z, gr_x, gr_y
# ----------------------- Contour traces ----------------------
def _polygon_area(x, y):
return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
def _colors(ncontours, colormap=None):
"""
Return a list of ``ncontours`` colors from the ``colormap`` colorscale.
"""
if colormap in clrs.PLOTLY_SCALES.keys():
cmap = clrs.PLOTLY_SCALES[colormap]
else:
raise exceptions.PlotlyError(
"Colorscale must be a valid Plotly Colorscale."
"The available colorscale names are {}".format(clrs.PLOTLY_SCALES.keys())
)
values = np.linspace(0, 1, ncontours)
vals_cmap = np.array([pair[0] for pair in cmap])
cols = np.array([pair[1] for pair in cmap])
inds = np.searchsorted(vals_cmap, values)
if "#" in cols[0]: # for Viridis
cols = [clrs.label_rgb(clrs.hex_to_rgb(col)) for col in cols]
colors = [cols[0]]
for ind, val in zip(inds[1:], values[1:]):
val1, val2 = vals_cmap[ind - 1], vals_cmap[ind]
interm = (val - val1) / (val2 - val1)
col = clrs.find_intermediate_color(
cols[ind - 1], cols[ind], interm, colortype="rgb"
)
colors.append(col)
return colors
def _is_invalid_contour(x, y):
"""
Utility function for _contour_trace
Contours with an area of the order as 1 pixel are considered spurious.
"""
too_small = np.all(np.abs(x - x[0]) < 2) and np.all(np.abs(y - y[0]) < 2)
return too_small
def _extract_contours(im, values, colors):
"""
Utility function for _contour_trace.
In ``im`` only one part of the domain has valid values (corresponding
to a subdomain where barycentric coordinates are well defined). When
computing contours, we need to assign values outside of this domain.
We can choose a value either smaller than all the values inside the
valid domain, or larger. This value must be chose with caution so that
no spurious contours are added. For example, if the boundary of the valid
domain has large values and the outer value is set to a small one, all
intermediate contours will be added at the boundary.
Therefore, we compute the two sets of contours (with an outer value
smaller of larger than all values in the valid domain), and choose
the value resulting in a smaller total number of contours. There might
be a faster way to do this, but it works...
"""
mask_nan = np.isnan(im)
im_min, im_max = (
im[np.logical_not(mask_nan)].min(),
im[np.logical_not(mask_nan)].max(),
)
zz_min = np.copy(im)
zz_min[mask_nan] = 2 * im_min
zz_max = np.copy(im)
zz_max[mask_nan] = 2 * im_max
all_contours1, all_values1, all_areas1, all_colors1 = [], [], [], []
all_contours2, all_values2, all_areas2, all_colors2 = [], [], [], []
for i, val in enumerate(values):
contour_level1 = measure.find_contours(zz_min, val)
contour_level2 = measure.find_contours(zz_max, val)
all_contours1.extend(contour_level1)
all_contours2.extend(contour_level2)
all_values1.extend([val] * len(contour_level1))
all_values2.extend([val] * len(contour_level2))
all_areas1.extend(
[_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level1]
)
all_areas2.extend(
[_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level2]
)
all_colors1.extend([colors[i]] * len(contour_level1))
all_colors2.extend([colors[i]] * len(contour_level2))
if len(all_contours1) <= len(all_contours2):
return all_contours1, all_values1, all_areas1, all_colors1
else:
return all_contours2, all_values2, all_areas2, all_colors2
def _add_outer_contour(
all_contours,
all_values,
all_areas,
all_colors,
values,
val_outer,
v_min,
v_max,
colors,
color_min,
color_max,
):
"""
Utility function for _contour_trace
Adds the background color to fill gaps outside of computed contours.
To compute the background color, the color of the contour with largest
area (``val_outer``) is used. As background color, we choose the next
color value in the direction of the extrema of the colormap.
Then we add information for the outer contour for the different lists
provided as arguments.
A discrete colormap with all used colors is also returned (to be used
by colorscale trace).
"""
# The exact value of outer contour is not used when defining the trace
outer_contour = 20 * np.array([[0, 0, 1], [0, 1, 0.5]]).T
all_contours = [outer_contour] + all_contours
delta_values = np.diff(values)[0]
values = np.concatenate(
([values[0] - delta_values], values, [values[-1] + delta_values])
)
colors = np.concatenate(([color_min], colors, [color_max]))
index = np.nonzero(values == val_outer)[0][0]
if index < len(values) / 2:
index -= 1
else:
index += 1
all_colors = [colors[index]] + all_colors
all_values = [values[index]] + all_values
all_areas = [0] + all_areas
used_colors = [color for color in colors if color in all_colors]
# Define discrete colorscale
color_number = len(used_colors)
scale = np.linspace(0, 1, color_number + 1)
discrete_cm = []
for i, color in enumerate(used_colors):
discrete_cm.append([scale[i], used_colors[i]])
discrete_cm.append([scale[i + 1], used_colors[i]])
discrete_cm.append([scale[color_number], used_colors[color_number - 1]])
return all_contours, all_values, all_areas, all_colors, discrete_cm
def _contour_trace(
x,
y,
z,
ncontours=None,
colorscale="Electric",
linecolor="rgb(150,150,150)",
interp_mode="llr",
coloring=None,
v_min=0,
v_max=1,
):
"""
Contour trace in Cartesian coordinates.
Parameters
==========
x, y : array-like
Cartesian coordinates
z : array-like
Field to be represented as contours.
ncontours : int or None
Number of contours to display (determined automatically if None).
colorscale : None or str (Plotly colormap)
colorscale of the contours.
linecolor : rgb color
Color used for lines. If ``colorscale`` is not None, line colors are
determined from ``colorscale`` instead.
interp_mode : 'ilr' (default) or 'cartesian'
Defines how data are interpolated to compute contours. If 'irl',
ILR (Isometric Log-Ratio) of compositional data is performed. If
'cartesian', contours are determined in Cartesian space.
coloring : None or 'lines'
How to display contour. Filled contours if None, lines if ``lines``.
vmin, vmax : float
Bounds of interval of values used for the colorspace
Notes
=====
"""
# Prepare colors
# We do not take extrema, for example for one single contour
# the color will be the middle point of the colormap
colors = _colors(ncontours + 2, colorscale)
# Values used for contours, extrema are not used
# For example for a binary array [0, 1], the value of
# the contour for ncontours=1 is 0.5.
values = np.linspace(v_min, v_max, ncontours + 2)
color_min, color_max = colors[0], colors[-1]
colors = colors[1:-1]
values = values[1:-1]
# Color of line contours
if linecolor is None:
linecolor = "rgb(150, 150, 150)"
else:
colors = [linecolor] * ncontours
# Retrieve all contours
all_contours, all_values, all_areas, all_colors = _extract_contours(
z, values, colors
)
# Now sort contours by decreasing area
order = np.argsort(all_areas)[::-1]
# Add outer contour
all_contours, all_values, all_areas, all_colors, discrete_cm = _add_outer_contour(
all_contours,
all_values,
all_areas,
all_colors,
values,
all_values[order[0]],
v_min,
v_max,
colors,
color_min,
color_max,
)
order = np.concatenate(([0], order + 1))
# Compute traces, in the order of decreasing area
traces = []
M, invM = _transform_barycentric_cartesian()
dx = (x.max() - x.min()) / x.size
dy = (y.max() - y.min()) / y.size
for index in order:
y_contour, x_contour = all_contours[index].T
val = all_values[index]
if interp_mode == "cartesian":
bar_coords = np.dot(
invM,
np.stack((dx * x_contour, dy * y_contour, np.ones(x_contour.shape))),
)
elif interp_mode == "ilr":
bar_coords = _ilr_inverse(
np.stack((dx * x_contour + x.min(), dy * y_contour + y.min()))
)
if index == 0: # outer triangle
a = np.array([1, 0, 0])
b = np.array([0, 1, 0])
c = np.array([0, 0, 1])
else:
a, b, c = bar_coords
if _is_invalid_contour(x_contour, y_contour):
continue
_col = all_colors[index] if coloring == "lines" else linecolor
trace = dict(
type="scatterternary",
a=a,
b=b,
c=c,
mode="lines",
line=dict(color=_col, shape="spline", width=1),
fill="toself",
fillcolor=all_colors[index],
showlegend=True,
hoverinfo="skip",
name="%.3f" % val,
)
if coloring == "lines":
trace["fill"] = None
traces.append(trace)
return traces, discrete_cm
# -------------------- Figure Factory for ternary contour -------------
def create_ternary_contour(
coordinates,
values,
pole_labels=["a", "b", "c"],
width=500,
height=500,
ncontours=None,
showscale=False,
coloring=None,
colorscale="Bluered",
linecolor=None,
title=None,
interp_mode="ilr",
showmarkers=False,
):
"""
Ternary contour plot.
Parameters
----------
coordinates : list or ndarray
Barycentric coordinates of shape (2, N) or (3, N) where N is the
number of data points. The sum of the 3 coordinates is expected
to be 1 for all data points.
values : array-like
Data points of field to be represented as contours.
pole_labels : str, default ['a', 'b', 'c']
Names of the three poles of the triangle.
width : int
Figure width.
height : int
Figure height.
ncontours : int or None
Number of contours to display (determined automatically if None).
showscale : bool, default False
If True, a colorbar showing the color scale is displayed.
coloring : None or 'lines'
How to display contour. Filled contours if None, lines if ``lines``.
colorscale : None or str (Plotly colormap)
colorscale of the contours.
linecolor : None or rgb color
Color used for lines. ``colorscale`` has to be set to None, otherwise
line colors are determined from ``colorscale``.
title : str or None
Title of ternary plot
interp_mode : 'ilr' (default) or 'cartesian'
Defines how data are interpolated to compute contours. If 'irl',
ILR (Isometric Log-Ratio) of compositional data is performed. If
'cartesian', contours are determined in Cartesian space.
showmarkers : bool, default False
If True, markers corresponding to input compositional points are
superimposed on contours, using the same colorscale.
Examples
========
Example 1: ternary contour plot with filled contours
>>> import plotly.figure_factory as ff
>>> import numpy as np
>>> # Define coordinates
>>> a, b = np.mgrid[0:1:20j, 0:1:20j]
>>> mask = a + b <= 1
>>> a = a[mask].ravel()
>>> b = b[mask].ravel()
>>> c = 1 - a - b
>>> # Values to be displayed as contours
>>> z = a * b * c
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z)
>>> fig.show()
It is also possible to give only two barycentric coordinates for each
point, since the sum of the three coordinates is one:
>>> fig = ff.create_ternary_contour(np.stack((a, b)), z)
Example 2: ternary contour plot with line contours
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines')
Example 3: customize number of contours
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, ncontours=8)
Example 4: superimpose contour plot and original data as markers
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines',
... showmarkers=True)
Example 5: customize title and pole labels
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z,
... title='Ternary plot',
... pole_labels=['clay', 'quartz', 'fledspar'])
"""
if scipy_interp is None:
raise ImportError(
"""\
The create_ternary_contour figure factory requires the scipy package"""
)
sk_measure = optional_imports.get_module("skimage")
if sk_measure is None:
raise ImportError(
"""\
The create_ternary_contour figure factory requires the scikit-image
package"""
)
if colorscale is None:
showscale = False
if ncontours is None:
ncontours = 5
coordinates = _prepare_barycentric_coord(coordinates)
v_min, v_max = values.min(), values.max()
grid_z, gr_x, gr_y = _compute_grid(coordinates, values, interp_mode=interp_mode)
layout = _ternary_layout(
pole_labels=pole_labels, width=width, height=height, title=title
)
contour_trace, discrete_cm = _contour_trace(
gr_x,
gr_y,
grid_z,
ncontours=ncontours,
colorscale=colorscale,
linecolor=linecolor,
interp_mode=interp_mode,
coloring=coloring,
v_min=v_min,
v_max=v_max,
)
fig = go.Figure(data=contour_trace, layout=layout)
opacity = 1 if showmarkers else 0
a, b, c = coordinates
hovertemplate = (
pole_labels[0]
+ ": %{a:.3f}<br>"
+ pole_labels[1]
+ ": %{b:.3f}<br>"
+ pole_labels[2]
+ ": %{c:.3f}<br>"
"z: %{marker.color:.3f}<extra></extra>"
)
fig.add_scatterternary(
a=a,
b=b,
c=c,
mode="markers",
marker={
"color": values,
"colorscale": colorscale,
"line": {"color": "rgb(120, 120, 120)", "width": int(coloring != "lines")},
},
opacity=opacity,
hovertemplate=hovertemplate,
)
if showscale:
if not showmarkers:
colorscale = discrete_cm
colorbar = dict(
{
"type": "scatterternary",
"a": [None],
"b": [None],
"c": [None],
"marker": {
"cmin": values.min(),
"cmax": values.max(),
"colorscale": colorscale,
"showscale": True,
},
"mode": "markers",
}
)
fig.add_trace(colorbar)
return fig
@@ -0,0 +1,511 @@
from __future__ import absolute_import
from plotly import exceptions, optional_imports
import plotly.colors as clrs
from plotly.graph_objs import graph_objs
np = optional_imports.get_module("numpy")
def map_face2color(face, colormap, scale, vmin, vmax):
"""
Normalize facecolor values by vmin/vmax and return rgb-color strings
This function takes a tuple color along with a colormap and a minimum
(vmin) and maximum (vmax) range of possible mean distances for the
given parametrized surface. It returns an rgb color based on the mean
distance between vmin and vmax
"""
if vmin >= vmax:
raise exceptions.PlotlyError(
"Incorrect relation between vmin "
"and vmax. The vmin value cannot be "
"bigger than or equal to the value "
"of vmax."
)
if len(colormap) == 1:
# color each triangle face with the same color in colormap
face_color = colormap[0]
face_color = clrs.convert_to_RGB_255(face_color)
face_color = clrs.label_rgb(face_color)
return face_color
if face == vmax:
# pick last color in colormap
face_color = colormap[-1]
face_color = clrs.convert_to_RGB_255(face_color)
face_color = clrs.label_rgb(face_color)
return face_color
else:
if scale is None:
# find the normalized distance t of a triangle face between
# vmin and vmax where the distance is between 0 and 1
t = (face - vmin) / float((vmax - vmin))
low_color_index = int(t / (1.0 / (len(colormap) - 1)))
face_color = clrs.find_intermediate_color(
colormap[low_color_index],
colormap[low_color_index + 1],
t * (len(colormap) - 1) - low_color_index,
)
face_color = clrs.convert_to_RGB_255(face_color)
face_color = clrs.label_rgb(face_color)
else:
# find the face color for a non-linearly interpolated scale
t = (face - vmin) / float((vmax - vmin))
low_color_index = 0
for k in range(len(scale) - 1):
if scale[k] <= t < scale[k + 1]:
break
low_color_index += 1
low_scale_val = scale[low_color_index]
high_scale_val = scale[low_color_index + 1]
face_color = clrs.find_intermediate_color(
colormap[low_color_index],
colormap[low_color_index + 1],
(t - low_scale_val) / (high_scale_val - low_scale_val),
)
face_color = clrs.convert_to_RGB_255(face_color)
face_color = clrs.label_rgb(face_color)
return face_color
def trisurf(
x,
y,
z,
simplices,
show_colorbar,
edges_color,
scale,
colormap=None,
color_func=None,
plot_edges=False,
x_edge=None,
y_edge=None,
z_edge=None,
facecolor=None,
):
"""
Refer to FigureFactory.create_trisurf() for docstring
"""
# numpy import check
if not np:
raise ImportError("FigureFactory._trisurf() requires " "numpy imported.")
points3D = np.vstack((x, y, z)).T
simplices = np.atleast_2d(simplices)
# vertices of the surface triangles
tri_vertices = points3D[simplices]
# Define colors for the triangle faces
if color_func is None:
# mean values of z-coordinates of triangle vertices
mean_dists = tri_vertices[:, :, 2].mean(-1)
elif isinstance(color_func, (list, np.ndarray)):
# Pre-computed list / array of values to map onto color
if len(color_func) != len(simplices):
raise ValueError(
"If color_func is a list/array, it must "
"be the same length as simplices."
)
# convert all colors in color_func to rgb
for index in range(len(color_func)):
if isinstance(color_func[index], str):
if "#" in color_func[index]:
foo = clrs.hex_to_rgb(color_func[index])
color_func[index] = clrs.label_rgb(foo)
if isinstance(color_func[index], tuple):
foo = clrs.convert_to_RGB_255(color_func[index])
color_func[index] = clrs.label_rgb(foo)
mean_dists = np.asarray(color_func)
else:
# apply user inputted function to calculate
# custom coloring for triangle vertices
mean_dists = []
for triangle in tri_vertices:
dists = []
for vertex in triangle:
dist = color_func(vertex[0], vertex[1], vertex[2])
dists.append(dist)
mean_dists.append(np.mean(dists))
mean_dists = np.asarray(mean_dists)
# Check if facecolors are already strings and can be skipped
if isinstance(mean_dists[0], str):
facecolor = mean_dists
else:
min_mean_dists = np.min(mean_dists)
max_mean_dists = np.max(mean_dists)
if facecolor is None:
facecolor = []
for index in range(len(mean_dists)):
color = map_face2color(
mean_dists[index], colormap, scale, min_mean_dists, max_mean_dists
)
facecolor.append(color)
# Make sure facecolor is a list so output is consistent across Pythons
facecolor = np.asarray(facecolor)
ii, jj, kk = simplices.T
triangles = graph_objs.Mesh3d(
x=x, y=y, z=z, facecolor=facecolor, i=ii, j=jj, k=kk, name=""
)
mean_dists_are_numbers = not isinstance(mean_dists[0], str)
if mean_dists_are_numbers and show_colorbar is True:
# make a colorscale from the colors
colorscale = clrs.make_colorscale(colormap, scale)
colorscale = clrs.convert_colorscale_to_rgb(colorscale)
colorbar = graph_objs.Scatter3d(
x=x[:1],
y=y[:1],
z=z[:1],
mode="markers",
marker=dict(
size=0.1,
color=[min_mean_dists, max_mean_dists],
colorscale=colorscale,
showscale=True,
),
hoverinfo="none",
showlegend=False,
)
# the triangle sides are not plotted
if plot_edges is False:
if mean_dists_are_numbers and show_colorbar is True:
return [triangles, colorbar]
else:
return [triangles]
# define the lists x_edge, y_edge and z_edge, of x, y, resp z
# coordinates of edge end points for each triangle
# None separates data corresponding to two consecutive triangles
is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
if any(is_none):
if not all(is_none):
raise ValueError(
"If any (x_edge, y_edge, z_edge) is None, " "all must be None"
)
else:
x_edge = []
y_edge = []
z_edge = []
# Pull indices we care about, then add a None column to separate tris
ixs_triangles = [0, 1, 2, 0]
pull_edges = tri_vertices[:, ixs_triangles, :]
x_edge_pull = np.hstack(
[pull_edges[:, :, 0], np.tile(None, [pull_edges.shape[0], 1])]
)
y_edge_pull = np.hstack(
[pull_edges[:, :, 1], np.tile(None, [pull_edges.shape[0], 1])]
)
z_edge_pull = np.hstack(
[pull_edges[:, :, 2], np.tile(None, [pull_edges.shape[0], 1])]
)
# Now unravel the edges into a 1-d vector for plotting
x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])
if not (len(x_edge) == len(y_edge) == len(z_edge)):
raise exceptions.PlotlyError(
"The lengths of x_edge, y_edge and " "z_edge are not the same."
)
# define the lines for plotting
lines = graph_objs.Scatter3d(
x=x_edge,
y=y_edge,
z=z_edge,
mode="lines",
line=graph_objs.scatter3d.Line(color=edges_color, width=1.5),
showlegend=False,
)
if mean_dists_are_numbers and show_colorbar is True:
return [triangles, lines, colorbar]
else:
return [triangles, lines]
def create_trisurf(
x,
y,
z,
simplices,
colormap=None,
show_colorbar=True,
scale=None,
color_func=None,
title="Trisurf Plot",
plot_edges=True,
showbackground=True,
backgroundcolor="rgb(230, 230, 230)",
gridcolor="rgb(255, 255, 255)",
zerolinecolor="rgb(255, 255, 255)",
edges_color="rgb(50, 50, 50)",
height=800,
width=800,
aspectratio=None,
):
"""
Returns figure for a triangulated surface plot
:param (array) x: data values of x in a 1D array
:param (array) y: data values of y in a 1D array
:param (array) z: data values of z in a 1D array
:param (array) simplices: an array of shape (ntri, 3) where ntri is
the number of triangles in the triangularization. Each row of the
array contains the indicies of the verticies of each triangle
:param (str|tuple|list) colormap: either a plotly scale name, an rgb
or hex color, a color tuple or a list of colors. An rgb color is
of the form 'rgb(x, y, z)' where x, y, z belong to the interval
[0, 255] and a color tuple is a tuple of the form (a, b, c) where
a, b and c belong to [0, 1]. If colormap is a list, it must
contain the valid color types aforementioned as its members
:param (bool) show_colorbar: determines if colorbar is visible
:param (list|array) scale: sets the scale values to be used if a non-
linearly interpolated colormap is desired. If left as None, a
linear interpolation between the colors will be excecuted
:param (function|list) color_func: The parameter that determines the
coloring of the surface. Takes either a function with 3 arguments
x, y, z or a list/array of color values the same length as
simplices. If None, coloring will only depend on the z axis
:param (str) title: title of the plot
:param (bool) plot_edges: determines if the triangles on the trisurf
are visible
:param (bool) showbackground: makes background in plot visible
:param (str) backgroundcolor: color of background. Takes a string of
the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
:param (str) gridcolor: color of the gridlines besides the axes. Takes
a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255
inclusive
:param (str) zerolinecolor: color of the axes. Takes a string of the
form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
:param (str) edges_color: color of the edges, if plot_edges is True
:param (int|float) height: the height of the plot (in pixels)
:param (int|float) width: the width of the plot (in pixels)
:param (dict) aspectratio: a dictionary of the aspect ratio values for
the x, y and z axes. 'x', 'y' and 'z' take (int|float) values
Example 1: Sphere
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u = np.linspace(0, 2*np.pi, 20)
>>> v = np.linspace(0, np.pi, 20)
>>> u,v = np.meshgrid(u,v)
>>> u = u.flatten()
>>> v = v.flatten()
>>> x = np.sin(v)*np.cos(u)
>>> y = np.sin(v)*np.sin(u)
>>> z = np.cos(v)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> # Create a figure
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow",
... simplices=simplices)
Example 2: Torus
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u = np.linspace(0, 2*np.pi, 20)
>>> v = np.linspace(0, 2*np.pi, 20)
>>> u,v = np.meshgrid(u,v)
>>> u = u.flatten()
>>> v = v.flatten()
>>> x = (3 + (np.cos(v)))*np.cos(u)
>>> y = (3 + (np.cos(v)))*np.sin(u)
>>> z = np.sin(v)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> # Create a figure
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis",
... simplices=simplices)
Example 3: Mobius Band
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u = np.linspace(0, 2*np.pi, 24)
>>> v = np.linspace(-1, 1, 8)
>>> u,v = np.meshgrid(u,v)
>>> u = u.flatten()
>>> v = v.flatten()
>>> tp = 1 + 0.5*v*np.cos(u/2.)
>>> x = tp*np.cos(u)
>>> y = tp*np.sin(u)
>>> z = 0.5*v*np.sin(u/2.)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> # Create a figure
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)],
... simplices=simplices)
Example 4: Using a Custom Colormap Function with Light Cone
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u=np.linspace(-np.pi, np.pi, 30)
>>> v=np.linspace(-np.pi, np.pi, 30)
>>> u,v=np.meshgrid(u,v)
>>> u=u.flatten()
>>> v=v.flatten()
>>> x = u
>>> y = u*np.cos(v)
>>> z = u*np.sin(v)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> # Define distance function
>>> def dist_origin(x, y, z):
... return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2)
>>> # Create a figure
>>> fig1 = create_trisurf(x=x, y=y, z=z,
... colormap=['#FFFFFF', '#E4FFFE',
... '#A4F6F9', '#FF99FE',
... '#BA52ED'],
... scale=[0, 0.6, 0.71, 0.89, 1],
... simplices=simplices,
... color_func=dist_origin)
Example 5: Enter color_func as a list of colors
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> import random
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u=np.linspace(-np.pi, np.pi, 30)
>>> v=np.linspace(-np.pi, np.pi, 30)
>>> u,v=np.meshgrid(u,v)
>>> u=u.flatten()
>>> v=v.flatten()
>>> x = u
>>> y = u*np.cos(v)
>>> z = u*np.sin(v)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> colors = []
>>> color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd']
>>> for index in range(len(simplices)):
... colors.append(random.choice(color_choices))
>>> fig = create_trisurf(
... x, y, z, simplices,
... color_func=colors,
... show_colorbar=True,
... edges_color='rgb(2, 85, 180)',
... title=' Modern Art'
... )
"""
if aspectratio is None:
aspectratio = {"x": 1, "y": 1, "z": 1}
# Validate colormap
clrs.validate_colors(colormap)
colormap, scale = clrs.convert_colors_to_same_type(
colormap, colortype="tuple", return_default_colors=True, scale=scale
)
data1 = trisurf(
x,
y,
z,
simplices,
show_colorbar=show_colorbar,
color_func=color_func,
colormap=colormap,
scale=scale,
edges_color=edges_color,
plot_edges=plot_edges,
)
axis = dict(
showbackground=showbackground,
backgroundcolor=backgroundcolor,
gridcolor=gridcolor,
zerolinecolor=zerolinecolor,
)
layout = graph_objs.Layout(
title=title,
width=width,
height=height,
scene=graph_objs.layout.Scene(
xaxis=graph_objs.layout.scene.XAxis(**axis),
yaxis=graph_objs.layout.scene.YAxis(**axis),
zaxis=graph_objs.layout.scene.ZAxis(**axis),
aspectratio=dict(
x=aspectratio["x"], y=aspectratio["y"], z=aspectratio["z"]
),
),
)
return graph_objs.Figure(data=data1, layout=layout)
@@ -0,0 +1,712 @@
from __future__ import absolute_import
from numbers import Number
from plotly import exceptions, optional_imports
import plotly.colors as clrs
from plotly.graph_objs import graph_objs
from plotly.subplots import make_subplots
pd = optional_imports.get_module("pandas")
np = optional_imports.get_module("numpy")
scipy_stats = optional_imports.get_module("scipy.stats")
def calc_stats(data):
"""
Calculate statistics for use in violin plot.
"""
x = np.asarray(data, np.float)
vals_min = np.min(x)
vals_max = np.max(x)
q2 = np.percentile(x, 50, interpolation="linear")
q1 = np.percentile(x, 25, interpolation="lower")
q3 = np.percentile(x, 75, interpolation="higher")
iqr = q3 - q1
whisker_dist = 1.5 * iqr
# in order to prevent drawing whiskers outside the interval
# of data one defines the whisker positions as:
d1 = np.min(x[x >= (q1 - whisker_dist)])
d2 = np.max(x[x <= (q3 + whisker_dist)])
return {
"min": vals_min,
"max": vals_max,
"q1": q1,
"q2": q2,
"q3": q3,
"d1": d1,
"d2": d2,
}
def make_half_violin(x, y, fillcolor="#1f77b4", linecolor="rgb(0, 0, 0)"):
"""
Produces a sideways probability distribution fig violin plot.
"""
text = [
"(pdf(y), y)=(" + "{:0.2f}".format(x[i]) + ", " + "{:0.2f}".format(y[i]) + ")"
for i in range(len(x))
]
return graph_objs.Scatter(
x=x,
y=y,
mode="lines",
name="",
text=text,
fill="tonextx",
fillcolor=fillcolor,
line=graph_objs.scatter.Line(width=0.5, color=linecolor, shape="spline"),
hoverinfo="text",
opacity=0.5,
)
def make_violin_rugplot(vals, pdf_max, distance, color="#1f77b4"):
"""
Returns a rugplot fig for a violin plot.
"""
return graph_objs.Scatter(
y=vals,
x=[-pdf_max - distance] * len(vals),
marker=graph_objs.scatter.Marker(color=color, symbol="line-ew-open"),
mode="markers",
name="",
showlegend=False,
hoverinfo="y",
)
def make_non_outlier_interval(d1, d2):
"""
Returns the scatterplot fig of most of a violin plot.
"""
return graph_objs.Scatter(
x=[0, 0],
y=[d1, d2],
name="",
mode="lines",
line=graph_objs.scatter.Line(width=1.5, color="rgb(0,0,0)"),
)
def make_quartiles(q1, q3):
"""
Makes the upper and lower quartiles for a violin plot.
"""
return graph_objs.Scatter(
x=[0, 0],
y=[q1, q3],
text=[
"lower-quartile: " + "{:0.2f}".format(q1),
"upper-quartile: " + "{:0.2f}".format(q3),
],
mode="lines",
line=graph_objs.scatter.Line(width=4, color="rgb(0,0,0)"),
hoverinfo="text",
)
def make_median(q2):
"""
Formats the 'median' hovertext for a violin plot.
"""
return graph_objs.Scatter(
x=[0],
y=[q2],
text=["median: " + "{:0.2f}".format(q2)],
mode="markers",
marker=dict(symbol="square", color="rgb(255,255,255)"),
hoverinfo="text",
)
def make_XAxis(xaxis_title, xaxis_range):
"""
Makes the x-axis for a violin plot.
"""
xaxis = graph_objs.layout.XAxis(
title=xaxis_title,
range=xaxis_range,
showgrid=False,
zeroline=False,
showline=False,
mirror=False,
ticks="",
showticklabels=False,
)
return xaxis
def make_YAxis(yaxis_title):
"""
Makes the y-axis for a violin plot.
"""
yaxis = graph_objs.layout.YAxis(
title=yaxis_title,
showticklabels=True,
autorange=True,
ticklen=4,
showline=True,
zeroline=False,
showgrid=False,
mirror=False,
)
return yaxis
def violinplot(vals, fillcolor="#1f77b4", rugplot=True):
"""
Refer to FigureFactory.create_violin() for docstring.
"""
vals = np.asarray(vals, np.float)
# summary statistics
vals_min = calc_stats(vals)["min"]
vals_max = calc_stats(vals)["max"]
q1 = calc_stats(vals)["q1"]
q2 = calc_stats(vals)["q2"]
q3 = calc_stats(vals)["q3"]
d1 = calc_stats(vals)["d1"]
d2 = calc_stats(vals)["d2"]
# kernel density estimation of pdf
pdf = scipy_stats.gaussian_kde(vals)
# grid over the data interval
xx = np.linspace(vals_min, vals_max, 100)
# evaluate the pdf at the grid xx
yy = pdf(xx)
max_pdf = np.max(yy)
# distance from the violin plot to rugplot
distance = (2.0 * max_pdf) / 10 if rugplot else 0
# range for x values in the plot
plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
plot_data = [
make_half_violin(-yy, xx, fillcolor=fillcolor),
make_half_violin(yy, xx, fillcolor=fillcolor),
make_non_outlier_interval(d1, d2),
make_quartiles(q1, q3),
make_median(q2),
]
if rugplot:
plot_data.append(
make_violin_rugplot(vals, max_pdf, distance=distance, color=fillcolor)
)
return plot_data, plot_xrange
def violin_no_colorscale(
data,
data_header,
group_header,
colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
):
"""
Refer to FigureFactory.create_violin() for docstring.
Returns fig for violin plot without colorscale.
"""
# collect all group names
group_name = []
for name in data[group_header]:
if name not in group_name:
group_name.append(name)
if sort:
group_name.sort()
gb = data.groupby([group_header])
L = len(group_name)
fig = make_subplots(
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
)
color_index = 0
for k, gr in enumerate(group_name):
vals = np.asarray(gb.get_group(gr)[data_header], np.float)
if color_index >= len(colors):
color_index = 0
plot_data, plot_xrange = violinplot(
vals, fillcolor=colors[color_index], rugplot=rugplot
)
layout = graph_objs.Layout()
for item in plot_data:
fig.append_trace(item, 1, k + 1)
color_index += 1
# add violin plot labels
fig["layout"].update(
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
)
# set the sharey axis style
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
fig["layout"].update(
title=title,
showlegend=False,
hovermode="closest",
autosize=False,
height=height,
width=width,
)
return fig
def violin_colorscale(
data,
data_header,
group_header,
colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
):
"""
Refer to FigureFactory.create_violin() for docstring.
Returns fig for violin plot with colorscale.
"""
# collect all group names
group_name = []
for name in data[group_header]:
if name not in group_name:
group_name.append(name)
if sort:
group_name.sort()
# make sure all group names are keys in group_stats
for group in group_name:
if group not in group_stats:
raise exceptions.PlotlyError(
"All values/groups in the index "
"column must be represented "
"as a key in group_stats."
)
gb = data.groupby([group_header])
L = len(group_name)
fig = make_subplots(
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
)
# prepare low and high color for colorscale
lowcolor = clrs.color_parser(colors[0], clrs.unlabel_rgb)
highcolor = clrs.color_parser(colors[1], clrs.unlabel_rgb)
# find min and max values in group_stats
group_stats_values = []
for key in group_stats:
group_stats_values.append(group_stats[key])
max_value = max(group_stats_values)
min_value = min(group_stats_values)
for k, gr in enumerate(group_name):
vals = np.asarray(gb.get_group(gr)[data_header], np.float)
# find intermediate color from colorscale
intermed = (group_stats[gr] - min_value) / (max_value - min_value)
intermed_color = clrs.find_intermediate_color(lowcolor, highcolor, intermed)
plot_data, plot_xrange = violinplot(
vals, fillcolor="rgb{}".format(intermed_color), rugplot=rugplot
)
layout = graph_objs.Layout()
for item in plot_data:
fig.append_trace(item, 1, k + 1)
fig["layout"].update(
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
)
# add colorbar to plot
trace_dummy = graph_objs.Scatter(
x=[0],
y=[0],
mode="markers",
marker=dict(
size=2,
cmin=min_value,
cmax=max_value,
colorscale=[[0, colors[0]], [1, colors[1]]],
showscale=True,
),
showlegend=False,
)
fig.append_trace(trace_dummy, 1, L)
# set the sharey axis style
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
fig["layout"].update(
title=title,
showlegend=False,
hovermode="closest",
autosize=False,
height=height,
width=width,
)
return fig
def violin_dict(
data,
data_header,
group_header,
colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
):
"""
Refer to FigureFactory.create_violin() for docstring.
Returns fig for violin plot without colorscale.
"""
# collect all group names
group_name = []
for name in data[group_header]:
if name not in group_name:
group_name.append(name)
if sort:
group_name.sort()
# check if all group names appear in colors dict
for group in group_name:
if group not in colors:
raise exceptions.PlotlyError(
"If colors is a dictionary, all "
"the group names must appear as "
"keys in colors."
)
gb = data.groupby([group_header])
L = len(group_name)
fig = make_subplots(
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
)
for k, gr in enumerate(group_name):
vals = np.asarray(gb.get_group(gr)[data_header], np.float)
plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr], rugplot=rugplot)
layout = graph_objs.Layout()
for item in plot_data:
fig.append_trace(item, 1, k + 1)
# add violin plot labels
fig["layout"].update(
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
)
# set the sharey axis style
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
fig["layout"].update(
title=title,
showlegend=False,
hovermode="closest",
autosize=False,
height=height,
width=width,
)
return fig
def create_violin(
data,
data_header=None,
group_header=None,
colors=None,
use_colorscale=False,
group_stats=None,
rugplot=True,
sort=False,
height=450,
width=600,
title="Violin and Rug Plot",
):
"""
**deprecated**, use instead the plotly.graph_objects trace
:class:`plotly.graph_objects.Violin`.
:param (list|array) data: accepts either a list of numerical values,
a list of dictionaries all with identical keys and at least one
column of numeric values, or a pandas dataframe with at least one
column of numbers.
:param (str) data_header: the header of the data column to be used
from an inputted pandas dataframe. Not applicable if 'data' is
a list of numeric values.
:param (str) group_header: applicable if grouping data by a variable.
'group_header' must be set to the name of the grouping variable.
:param (str|tuple|list|dict) colors: either a plotly scale name,
an rgb or hex color, a color tuple, a list of colors or a
dictionary. An rgb color is of the form 'rgb(x, y, z)' where
x, y and z belong to the interval [0, 255] and a color tuple is a
tuple of the form (a, b, c) where a, b and c belong to [0, 1].
If colors is a list, it must contain valid color types as its
members.
:param (bool) use_colorscale: only applicable if grouping by another
variable. Will implement a colorscale based on the first 2 colors
of param colors. This means colors must be a list with at least 2
colors in it (Plotly colorscales are accepted since they map to a
list of two rgb colors). Default = False
:param (dict) group_stats: a dictionary where each key is a unique
value from the group_header column in data. Each value must be a
number and will be used to color the violin plots if a colorscale
is being used.
:param (bool) rugplot: determines if a rugplot is draw on violin plot.
Default = True
:param (bool) sort: determines if violins are sorted
alphabetically (True) or by input order (False). Default = False
:param (float) height: the height of the violin plot.
:param (float) width: the width of the violin plot.
:param (str) title: the title of the violin plot.
Example 1: Single Violin Plot
>>> from plotly.figure_factory import create_violin
>>> import plotly.graph_objs as graph_objects
>>> import numpy as np
>>> from scipy import stats
>>> # create list of random values
>>> data_list = np.random.randn(100)
>>> # create violin fig
>>> fig = create_violin(data_list, colors='#604d9e')
>>> # plot
>>> fig.show()
Example 2: Multiple Violin Plots with Qualitative Coloring
>>> from plotly.figure_factory import create_violin
>>> import plotly.graph_objs as graph_objects
>>> import numpy as np
>>> import pandas as pd
>>> from scipy import stats
>>> # create dataframe
>>> np.random.seed(619517)
>>> Nr=250
>>> y = np.random.randn(Nr)
>>> gr = np.random.choice(list("ABCDE"), Nr)
>>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
>>> for i, letter in enumerate("ABCDE"):
... y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
>>> df = pd.DataFrame(dict(Score=y, Group=gr))
>>> # create violin fig
>>> fig = create_violin(df, data_header='Score', group_header='Group',
... sort=True, height=600, width=1000)
>>> # plot
>>> fig.show()
Example 3: Violin Plots with Colorscale
>>> from plotly.figure_factory import create_violin
>>> import plotly.graph_objs as graph_objects
>>> import numpy as np
>>> import pandas as pd
>>> from scipy import stats
>>> # create dataframe
>>> np.random.seed(619517)
>>> Nr=250
>>> y = np.random.randn(Nr)
>>> gr = np.random.choice(list("ABCDE"), Nr)
>>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
>>> for i, letter in enumerate("ABCDE"):
... y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
>>> df = pd.DataFrame(dict(Score=y, Group=gr))
>>> # define header params
>>> data_header = 'Score'
>>> group_header = 'Group'
>>> # make groupby object with pandas
>>> group_stats = {}
>>> groupby_data = df.groupby([group_header])
>>> for group in "ABCDE":
... data_from_group = groupby_data.get_group(group)[data_header]
... # take a stat of the grouped data
... stat = np.median(data_from_group)
... # add to dictionary
... group_stats[group] = stat
>>> # create violin fig
>>> fig = create_violin(df, data_header='Score', group_header='Group',
... height=600, width=1000, use_colorscale=True,
... group_stats=group_stats)
>>> # plot
>>> fig.show()
"""
# Validate colors
if isinstance(colors, dict):
valid_colors = clrs.validate_colors_dict(colors, "rgb")
else:
valid_colors = clrs.validate_colors(colors, "rgb")
# validate data and choose plot type
if group_header is None:
if isinstance(data, list):
if len(data) <= 0:
raise exceptions.PlotlyError(
"If data is a list, it must be "
"nonempty and contain either "
"numbers or dictionaries."
)
if not all(isinstance(element, Number) for element in data):
raise exceptions.PlotlyError(
"If data is a list, it must " "contain only numbers."
)
if pd and isinstance(data, pd.core.frame.DataFrame):
if data_header is None:
raise exceptions.PlotlyError(
"data_header must be the "
"column name with the "
"desired numeric data for "
"the violin plot."
)
data = data[data_header].values.tolist()
# call the plotting functions
plot_data, plot_xrange = violinplot(
data, fillcolor=valid_colors[0], rugplot=rugplot
)
layout = graph_objs.Layout(
title=title,
autosize=False,
font=graph_objs.layout.Font(size=11),
height=height,
showlegend=False,
width=width,
xaxis=make_XAxis("", plot_xrange),
yaxis=make_YAxis(""),
hovermode="closest",
)
layout["yaxis"].update(dict(showline=False, showticklabels=False, ticks=""))
fig = graph_objs.Figure(data=plot_data, layout=layout)
return fig
else:
if not isinstance(data, pd.core.frame.DataFrame):
raise exceptions.PlotlyError(
"Error. You must use a pandas "
"DataFrame if you are using a "
"group header."
)
if data_header is None:
raise exceptions.PlotlyError(
"data_header must be the column "
"name with the desired numeric "
"data for the violin plot."
)
if use_colorscale is False:
if isinstance(valid_colors, dict):
# validate colors dict choice below
fig = violin_dict(
data,
data_header,
group_header,
valid_colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
)
return fig
else:
fig = violin_no_colorscale(
data,
data_header,
group_header,
valid_colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
)
return fig
else:
if isinstance(valid_colors, dict):
raise exceptions.PlotlyError(
"The colors param cannot be "
"a dictionary if you are "
"using a colorscale."
)
if len(valid_colors) < 2:
raise exceptions.PlotlyError(
"colors must be a list with "
"at least 2 colors. A "
"Plotly scale is allowed."
)
if not isinstance(group_stats, dict):
raise exceptions.PlotlyError(
"Your group_stats param " "must be a dictionary."
)
fig = violin_colorscale(
data,
data_header,
group_header,
valid_colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
)
return fig
@@ -0,0 +1,274 @@
from __future__ import absolute_import
import decimal
from plotly import exceptions
from plotly.colors import (
DEFAULT_PLOTLY_COLORS,
PLOTLY_SCALES,
color_parser,
colorscale_to_colors,
colorscale_to_scale,
convert_to_RGB_255,
find_intermediate_color,
hex_to_rgb,
label_rgb,
n_colors,
unconvert_from_RGB_255,
unlabel_rgb,
validate_colors,
validate_colors_dict,
validate_colorscale,
validate_scale_values,
)
try:
from collections.abc import Sequence
except ImportError:
from collections import Sequence
def is_sequence(obj):
return isinstance(obj, Sequence) and not isinstance(obj, str)
def validate_index(index_vals):
"""
Validates if a list contains all numbers or all strings
:raises: (PlotlyError) If there are any two items in the list whose
types differ
"""
from numbers import Number
if isinstance(index_vals[0], Number):
if not all(isinstance(item, Number) for item in index_vals):
raise exceptions.PlotlyError(
"Error in indexing column. "
"Make sure all entries of each "
"column are all numbers or "
"all strings."
)
elif isinstance(index_vals[0], str):
if not all(isinstance(item, str) for item in index_vals):
raise exceptions.PlotlyError(
"Error in indexing column. "
"Make sure all entries of each "
"column are all numbers or "
"all strings."
)
def validate_dataframe(array):
"""
Validates all strings or numbers in each dataframe column
:raises: (PlotlyError) If there are any two items in any list whose
types differ
"""
from numbers import Number
for vector in array:
if isinstance(vector[0], Number):
if not all(isinstance(item, Number) for item in vector):
raise exceptions.PlotlyError(
"Error in dataframe. "
"Make sure all entries of "
"each column are either "
"numbers or strings."
)
elif isinstance(vector[0], str):
if not all(isinstance(item, str) for item in vector):
raise exceptions.PlotlyError(
"Error in dataframe. "
"Make sure all entries of "
"each column are either "
"numbers or strings."
)
def validate_equal_length(*args):
"""
Validates that data lists or ndarrays are the same length.
:raises: (PlotlyError) If any data lists are not the same length.
"""
length = len(args[0])
if any(len(lst) != length for lst in args):
raise exceptions.PlotlyError(
"Oops! Your data lists or ndarrays " "should be the same length."
)
def validate_positive_scalars(**kwargs):
"""
Validates that all values given in key/val pairs are positive.
Accepts kwargs to improve Exception messages.
:raises: (PlotlyError) If any value is < 0 or raises.
"""
for key, val in kwargs.items():
try:
if val <= 0:
raise ValueError("{} must be > 0, got {}".format(key, val))
except TypeError:
raise exceptions.PlotlyError("{} must be a number, got {}".format(key, val))
def flatten(array):
"""
Uses list comprehension to flatten array
:param (array): An iterable to flatten
:raises (PlotlyError): If iterable is not nested.
:rtype (list): The flattened list.
"""
try:
return [item for sublist in array for item in sublist]
except TypeError:
raise exceptions.PlotlyError(
"Your data array could not be "
"flattened! Make sure your data is "
"entered as lists or ndarrays!"
)
def endpts_to_intervals(endpts):
"""
Returns a list of intervals for categorical colormaps
Accepts a list or tuple of sequentially increasing numbers and returns
a list representation of the mathematical intervals with these numbers
as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]
:raises: (PlotlyError) If input is not a list or tuple
:raises: (PlotlyError) If the input contains a string
:raises: (PlotlyError) If any number does not increase after the
previous one in the sequence
"""
length = len(endpts)
# Check if endpts is a list or tuple
if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))):
raise exceptions.PlotlyError(
"The intervals_endpts argument must "
"be a list or tuple of a sequence "
"of increasing numbers."
)
# Check if endpts contains only numbers
for item in endpts:
if isinstance(item, str):
raise exceptions.PlotlyError(
"The intervals_endpts argument "
"must be a list or tuple of a "
"sequence of increasing "
"numbers."
)
# Check if numbers in endpts are increasing
for k in range(length - 1):
if endpts[k] >= endpts[k + 1]:
raise exceptions.PlotlyError(
"The intervals_endpts argument "
"must be a list or tuple of a "
"sequence of increasing "
"numbers."
)
else:
intervals = []
# add -inf to intervals
intervals.append([float("-inf"), endpts[0]])
for k in range(length - 1):
interval = []
interval.append(endpts[k])
interval.append(endpts[k + 1])
intervals.append(interval)
# add +inf to intervals
intervals.append([endpts[length - 1], float("inf")])
return intervals
def annotation_dict_for_label(
text,
lane,
num_of_lanes,
subplot_spacing,
row_col="col",
flipped=True,
right_side=True,
text_color="#0f0f0f",
):
"""
Returns annotation dict for label of n labels of a 1xn or nx1 subplot.
:param (str) text: the text for a label.
:param (int) lane: the label number for text. From 1 to n inclusive.
:param (int) num_of_lanes: the number 'n' of rows or columns in subplot.
:param (float) subplot_spacing: the value for the horizontal_spacing and
vertical_spacing params in your plotly.tools.make_subplots() call.
:param (str) row_col: choose whether labels are placed along rows or
columns.
:param (bool) flipped: flips text by 90 degrees. Text is printed
horizontally if set to True and row_col='row', or if False and
row_col='col'.
:param (bool) right_side: only applicable if row_col is set to 'row'.
:param (str) text_color: color of the text.
"""
l = (1 - (num_of_lanes - 1) * subplot_spacing) / (num_of_lanes)
if not flipped:
xanchor = "center"
yanchor = "middle"
if row_col == "col":
x = (lane - 1) * (l + subplot_spacing) + 0.5 * l
y = 1.03
textangle = 0
elif row_col == "row":
y = (lane - 1) * (l + subplot_spacing) + 0.5 * l
x = 1.03
textangle = 90
else:
if row_col == "col":
xanchor = "center"
yanchor = "bottom"
x = (lane - 1) * (l + subplot_spacing) + 0.5 * l
y = 1.0
textangle = 270
elif row_col == "row":
yanchor = "middle"
y = (lane - 1) * (l + subplot_spacing) + 0.5 * l
if right_side:
x = 1.0
xanchor = "left"
else:
x = -0.01
xanchor = "right"
textangle = 0
annotation_dict = dict(
textangle=textangle,
xanchor=xanchor,
yanchor=yanchor,
x=x,
y=y,
showarrow=False,
xref="paper",
yref="paper",
text=text,
font=dict(size=13, color=text_color),
)
return annotation_dict
def list_of_options(iterable, conj="and", period=True):
"""
Returns an English listing of objects seperated by commas ','
For example, ['foo', 'bar', 'baz'] becomes 'foo, bar and baz'
if the conjunction 'and' is selected.
"""
if len(iterable) < 2:
raise exceptions.PlotlyError(
"Your list or tuple must contain at least 2 items."
)
template = (len(iterable) - 2) * "{}, " + "{} " + conj + " {}" + period * "."
return template.format(*iterable)
+2
View File
@@ -0,0 +1,2 @@
from __future__ import absolute_import
from _plotly_utils.files import *
@@ -0,0 +1,299 @@
import sys
from typing import TYPE_CHECKING
if sys.version_info < (3, 7) or TYPE_CHECKING:
from ..graph_objs import Waterfall
from ..graph_objs import Volume
from ..graph_objs import Violin
from ..graph_objs import Treemap
from ..graph_objs import Table
from ..graph_objs import Surface
from ..graph_objs import Sunburst
from ..graph_objs import Streamtube
from ..graph_objs import Splom
from ..graph_objs import Scatterternary
from ..graph_objs import Scattersmith
from ..graph_objs import Scatterpolargl
from ..graph_objs import Scatterpolar
from ..graph_objs import Scattermapbox
from ..graph_objs import Scattergl
from ..graph_objs import Scattergeo
from ..graph_objs import Scattercarpet
from ..graph_objs import Scatter3d
from ..graph_objs import Scatter
from ..graph_objs import Sankey
from ..graph_objs import Pointcloud
from ..graph_objs import Pie
from ..graph_objs import Parcoords
from ..graph_objs import Parcats
from ..graph_objs import Ohlc
from ..graph_objs import Mesh3d
from ..graph_objs import Isosurface
from ..graph_objs import Indicator
from ..graph_objs import Image
from ..graph_objs import Icicle
from ..graph_objs import Histogram2dContour
from ..graph_objs import Histogram2d
from ..graph_objs import Histogram
from ..graph_objs import Heatmapgl
from ..graph_objs import Heatmap
from ..graph_objs import Funnelarea
from ..graph_objs import Funnel
from ..graph_objs import Densitymapbox
from ..graph_objs import Contourcarpet
from ..graph_objs import Contour
from ..graph_objs import Cone
from ..graph_objs import Choroplethmapbox
from ..graph_objs import Choropleth
from ..graph_objs import Carpet
from ..graph_objs import Candlestick
from ..graph_objs import Box
from ..graph_objs import Barpolar
from ..graph_objs import Bar
from ..graph_objs import Layout
from ..graph_objs import Frame
from ..graph_objs import Figure
from ..graph_objs import Data
from ..graph_objs import Annotations
from ..graph_objs import Frames
from ..graph_objs import AngularAxis
from ..graph_objs import Annotation
from ..graph_objs import ColorBar
from ..graph_objs import Contours
from ..graph_objs import ErrorX
from ..graph_objs import ErrorY
from ..graph_objs import ErrorZ
from ..graph_objs import Font
from ..graph_objs import Legend
from ..graph_objs import Line
from ..graph_objs import Margin
from ..graph_objs import Marker
from ..graph_objs import RadialAxis
from ..graph_objs import Scene
from ..graph_objs import Stream
from ..graph_objs import XAxis
from ..graph_objs import YAxis
from ..graph_objs import ZAxis
from ..graph_objs import XBins
from ..graph_objs import YBins
from ..graph_objs import Trace
from ..graph_objs import Histogram2dcontour
from ..graph_objs import waterfall
from ..graph_objs import volume
from ..graph_objs import violin
from ..graph_objs import treemap
from ..graph_objs import table
from ..graph_objs import surface
from ..graph_objs import sunburst
from ..graph_objs import streamtube
from ..graph_objs import splom
from ..graph_objs import scatterternary
from ..graph_objs import scattersmith
from ..graph_objs import scatterpolargl
from ..graph_objs import scatterpolar
from ..graph_objs import scattermapbox
from ..graph_objs import scattergl
from ..graph_objs import scattergeo
from ..graph_objs import scattercarpet
from ..graph_objs import scatter3d
from ..graph_objs import scatter
from ..graph_objs import sankey
from ..graph_objs import pointcloud
from ..graph_objs import pie
from ..graph_objs import parcoords
from ..graph_objs import parcats
from ..graph_objs import ohlc
from ..graph_objs import mesh3d
from ..graph_objs import isosurface
from ..graph_objs import indicator
from ..graph_objs import image
from ..graph_objs import icicle
from ..graph_objs import histogram2dcontour
from ..graph_objs import histogram2d
from ..graph_objs import histogram
from ..graph_objs import heatmapgl
from ..graph_objs import heatmap
from ..graph_objs import funnelarea
from ..graph_objs import funnel
from ..graph_objs import densitymapbox
from ..graph_objs import contourcarpet
from ..graph_objs import contour
from ..graph_objs import cone
from ..graph_objs import choroplethmapbox
from ..graph_objs import choropleth
from ..graph_objs import carpet
from ..graph_objs import candlestick
from ..graph_objs import box
from ..graph_objs import barpolar
from ..graph_objs import bar
from ..graph_objs import layout
else:
from _plotly_utils.importers import relative_import
__all__, __getattr__, __dir__ = relative_import(
__name__,
[
"..graph_objs.waterfall",
"..graph_objs.volume",
"..graph_objs.violin",
"..graph_objs.treemap",
"..graph_objs.table",
"..graph_objs.surface",
"..graph_objs.sunburst",
"..graph_objs.streamtube",
"..graph_objs.splom",
"..graph_objs.scatterternary",
"..graph_objs.scattersmith",
"..graph_objs.scatterpolargl",
"..graph_objs.scatterpolar",
"..graph_objs.scattermapbox",
"..graph_objs.scattergl",
"..graph_objs.scattergeo",
"..graph_objs.scattercarpet",
"..graph_objs.scatter3d",
"..graph_objs.scatter",
"..graph_objs.sankey",
"..graph_objs.pointcloud",
"..graph_objs.pie",
"..graph_objs.parcoords",
"..graph_objs.parcats",
"..graph_objs.ohlc",
"..graph_objs.mesh3d",
"..graph_objs.isosurface",
"..graph_objs.indicator",
"..graph_objs.image",
"..graph_objs.icicle",
"..graph_objs.histogram2dcontour",
"..graph_objs.histogram2d",
"..graph_objs.histogram",
"..graph_objs.heatmapgl",
"..graph_objs.heatmap",
"..graph_objs.funnelarea",
"..graph_objs.funnel",
"..graph_objs.densitymapbox",
"..graph_objs.contourcarpet",
"..graph_objs.contour",
"..graph_objs.cone",
"..graph_objs.choroplethmapbox",
"..graph_objs.choropleth",
"..graph_objs.carpet",
"..graph_objs.candlestick",
"..graph_objs.box",
"..graph_objs.barpolar",
"..graph_objs.bar",
"..graph_objs.layout",
],
[
"..graph_objs.Waterfall",
"..graph_objs.Volume",
"..graph_objs.Violin",
"..graph_objs.Treemap",
"..graph_objs.Table",
"..graph_objs.Surface",
"..graph_objs.Sunburst",
"..graph_objs.Streamtube",
"..graph_objs.Splom",
"..graph_objs.Scatterternary",
"..graph_objs.Scattersmith",
"..graph_objs.Scatterpolargl",
"..graph_objs.Scatterpolar",
"..graph_objs.Scattermapbox",
"..graph_objs.Scattergl",
"..graph_objs.Scattergeo",
"..graph_objs.Scattercarpet",
"..graph_objs.Scatter3d",
"..graph_objs.Scatter",
"..graph_objs.Sankey",
"..graph_objs.Pointcloud",
"..graph_objs.Pie",
"..graph_objs.Parcoords",
"..graph_objs.Parcats",
"..graph_objs.Ohlc",
"..graph_objs.Mesh3d",
"..graph_objs.Isosurface",
"..graph_objs.Indicator",
"..graph_objs.Image",
"..graph_objs.Icicle",
"..graph_objs.Histogram2dContour",
"..graph_objs.Histogram2d",
"..graph_objs.Histogram",
"..graph_objs.Heatmapgl",
"..graph_objs.Heatmap",
"..graph_objs.Funnelarea",
"..graph_objs.Funnel",
"..graph_objs.Densitymapbox",
"..graph_objs.Contourcarpet",
"..graph_objs.Contour",
"..graph_objs.Cone",
"..graph_objs.Choroplethmapbox",
"..graph_objs.Choropleth",
"..graph_objs.Carpet",
"..graph_objs.Candlestick",
"..graph_objs.Box",
"..graph_objs.Barpolar",
"..graph_objs.Bar",
"..graph_objs.Layout",
"..graph_objs.Frame",
"..graph_objs.Figure",
"..graph_objs.Data",
"..graph_objs.Annotations",
"..graph_objs.Frames",
"..graph_objs.AngularAxis",
"..graph_objs.Annotation",
"..graph_objs.ColorBar",
"..graph_objs.Contours",
"..graph_objs.ErrorX",
"..graph_objs.ErrorY",
"..graph_objs.ErrorZ",
"..graph_objs.Font",
"..graph_objs.Legend",
"..graph_objs.Line",
"..graph_objs.Margin",
"..graph_objs.Marker",
"..graph_objs.RadialAxis",
"..graph_objs.Scene",
"..graph_objs.Stream",
"..graph_objs.XAxis",
"..graph_objs.YAxis",
"..graph_objs.ZAxis",
"..graph_objs.XBins",
"..graph_objs.YBins",
"..graph_objs.Trace",
"..graph_objs.Histogram2dcontour",
],
)
if sys.version_info < (3, 7) or TYPE_CHECKING:
try:
import ipywidgets as _ipywidgets
from distutils.version import LooseVersion as _LooseVersion
if _LooseVersion(_ipywidgets.__version__) >= _LooseVersion("7.0.0"):
from ..graph_objs._figurewidget import FigureWidget
else:
raise ImportError()
except Exception:
from ..missing_ipywidgets import FigureWidget
else:
__all__.append("FigureWidget")
orig_getattr = __getattr__
def __getattr__(import_name):
if import_name == "FigureWidget":
try:
import ipywidgets
from distutils.version import LooseVersion
if LooseVersion(ipywidgets.__version__) >= LooseVersion("7.0.0"):
from ..graph_objs._figurewidget import FigureWidget
return FigureWidget
else:
raise ImportError()
except Exception:
from ..missing_ipywidgets import FigureWidget
return FigureWidget
return orig_getattr(import_name)
@@ -0,0 +1,299 @@
import sys
from typing import TYPE_CHECKING
if sys.version_info < (3, 7) or TYPE_CHECKING:
from ._bar import Bar
from ._barpolar import Barpolar
from ._box import Box
from ._candlestick import Candlestick
from ._carpet import Carpet
from ._choropleth import Choropleth
from ._choroplethmapbox import Choroplethmapbox
from ._cone import Cone
from ._contour import Contour
from ._contourcarpet import Contourcarpet
from ._densitymapbox import Densitymapbox
from ._deprecations import AngularAxis
from ._deprecations import Annotation
from ._deprecations import Annotations
from ._deprecations import ColorBar
from ._deprecations import Contours
from ._deprecations import Data
from ._deprecations import ErrorX
from ._deprecations import ErrorY
from ._deprecations import ErrorZ
from ._deprecations import Font
from ._deprecations import Frames
from ._deprecations import Histogram2dcontour
from ._deprecations import Legend
from ._deprecations import Line
from ._deprecations import Margin
from ._deprecations import Marker
from ._deprecations import RadialAxis
from ._deprecations import Scene
from ._deprecations import Stream
from ._deprecations import Trace
from ._deprecations import XAxis
from ._deprecations import XBins
from ._deprecations import YAxis
from ._deprecations import YBins
from ._deprecations import ZAxis
from ._figure import Figure
from ._frame import Frame
from ._funnel import Funnel
from ._funnelarea import Funnelarea
from ._heatmap import Heatmap
from ._heatmapgl import Heatmapgl
from ._histogram import Histogram
from ._histogram2d import Histogram2d
from ._histogram2dcontour import Histogram2dContour
from ._icicle import Icicle
from ._image import Image
from ._indicator import Indicator
from ._isosurface import Isosurface
from ._layout import Layout
from ._mesh3d import Mesh3d
from ._ohlc import Ohlc
from ._parcats import Parcats
from ._parcoords import Parcoords
from ._pie import Pie
from ._pointcloud import Pointcloud
from ._sankey import Sankey
from ._scatter import Scatter
from ._scatter3d import Scatter3d
from ._scattercarpet import Scattercarpet
from ._scattergeo import Scattergeo
from ._scattergl import Scattergl
from ._scattermapbox import Scattermapbox
from ._scatterpolar import Scatterpolar
from ._scatterpolargl import Scatterpolargl
from ._scattersmith import Scattersmith
from ._scatterternary import Scatterternary
from ._splom import Splom
from ._streamtube import Streamtube
from ._sunburst import Sunburst
from ._surface import Surface
from ._table import Table
from ._treemap import Treemap
from ._violin import Violin
from ._volume import Volume
from ._waterfall import Waterfall
from . import bar
from . import barpolar
from . import box
from . import candlestick
from . import carpet
from . import choropleth
from . import choroplethmapbox
from . import cone
from . import contour
from . import contourcarpet
from . import densitymapbox
from . import funnel
from . import funnelarea
from . import heatmap
from . import heatmapgl
from . import histogram
from . import histogram2d
from . import histogram2dcontour
from . import icicle
from . import image
from . import indicator
from . import isosurface
from . import layout
from . import mesh3d
from . import ohlc
from . import parcats
from . import parcoords
from . import pie
from . import pointcloud
from . import sankey
from . import scatter
from . import scatter3d
from . import scattercarpet
from . import scattergeo
from . import scattergl
from . import scattermapbox
from . import scatterpolar
from . import scatterpolargl
from . import scattersmith
from . import scatterternary
from . import splom
from . import streamtube
from . import sunburst
from . import surface
from . import table
from . import treemap
from . import violin
from . import volume
from . import waterfall
else:
from _plotly_utils.importers import relative_import
__all__, __getattr__, __dir__ = relative_import(
__name__,
[
".bar",
".barpolar",
".box",
".candlestick",
".carpet",
".choropleth",
".choroplethmapbox",
".cone",
".contour",
".contourcarpet",
".densitymapbox",
".funnel",
".funnelarea",
".heatmap",
".heatmapgl",
".histogram",
".histogram2d",
".histogram2dcontour",
".icicle",
".image",
".indicator",
".isosurface",
".layout",
".mesh3d",
".ohlc",
".parcats",
".parcoords",
".pie",
".pointcloud",
".sankey",
".scatter",
".scatter3d",
".scattercarpet",
".scattergeo",
".scattergl",
".scattermapbox",
".scatterpolar",
".scatterpolargl",
".scattersmith",
".scatterternary",
".splom",
".streamtube",
".sunburst",
".surface",
".table",
".treemap",
".violin",
".volume",
".waterfall",
],
[
"._bar.Bar",
"._barpolar.Barpolar",
"._box.Box",
"._candlestick.Candlestick",
"._carpet.Carpet",
"._choropleth.Choropleth",
"._choroplethmapbox.Choroplethmapbox",
"._cone.Cone",
"._contour.Contour",
"._contourcarpet.Contourcarpet",
"._densitymapbox.Densitymapbox",
"._deprecations.AngularAxis",
"._deprecations.Annotation",
"._deprecations.Annotations",
"._deprecations.ColorBar",
"._deprecations.Contours",
"._deprecations.Data",
"._deprecations.ErrorX",
"._deprecations.ErrorY",
"._deprecations.ErrorZ",
"._deprecations.Font",
"._deprecations.Frames",
"._deprecations.Histogram2dcontour",
"._deprecations.Legend",
"._deprecations.Line",
"._deprecations.Margin",
"._deprecations.Marker",
"._deprecations.RadialAxis",
"._deprecations.Scene",
"._deprecations.Stream",
"._deprecations.Trace",
"._deprecations.XAxis",
"._deprecations.XBins",
"._deprecations.YAxis",
"._deprecations.YBins",
"._deprecations.ZAxis",
"._figure.Figure",
"._frame.Frame",
"._funnel.Funnel",
"._funnelarea.Funnelarea",
"._heatmap.Heatmap",
"._heatmapgl.Heatmapgl",
"._histogram.Histogram",
"._histogram2d.Histogram2d",
"._histogram2dcontour.Histogram2dContour",
"._icicle.Icicle",
"._image.Image",
"._indicator.Indicator",
"._isosurface.Isosurface",
"._layout.Layout",
"._mesh3d.Mesh3d",
"._ohlc.Ohlc",
"._parcats.Parcats",
"._parcoords.Parcoords",
"._pie.Pie",
"._pointcloud.Pointcloud",
"._sankey.Sankey",
"._scatter.Scatter",
"._scatter3d.Scatter3d",
"._scattercarpet.Scattercarpet",
"._scattergeo.Scattergeo",
"._scattergl.Scattergl",
"._scattermapbox.Scattermapbox",
"._scatterpolar.Scatterpolar",
"._scatterpolargl.Scatterpolargl",
"._scattersmith.Scattersmith",
"._scatterternary.Scatterternary",
"._splom.Splom",
"._streamtube.Streamtube",
"._sunburst.Sunburst",
"._surface.Surface",
"._table.Table",
"._treemap.Treemap",
"._violin.Violin",
"._volume.Volume",
"._waterfall.Waterfall",
],
)
if sys.version_info < (3, 7) or TYPE_CHECKING:
try:
import ipywidgets as _ipywidgets
from distutils.version import LooseVersion as _LooseVersion
if _LooseVersion(_ipywidgets.__version__) >= _LooseVersion("7.0.0"):
from ..graph_objs._figurewidget import FigureWidget
else:
raise ImportError()
except Exception:
from ..missing_ipywidgets import FigureWidget
else:
__all__.append("FigureWidget")
orig_getattr = __getattr__
def __getattr__(import_name):
if import_name == "FigureWidget":
try:
import ipywidgets
from distutils.version import LooseVersion
if LooseVersion(ipywidgets.__version__) >= LooseVersion("7.0.0"):
from ..graph_objs._figurewidget import FigureWidget
return FigureWidget
else:
raise ImportError()
except Exception:
from ..missing_ipywidgets import FigureWidget
return FigureWidget
return orig_getattr(import_name)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,723 @@
import warnings
warnings.filterwarnings(
"default", r"plotly\.graph_objs\.\w+ is deprecated", DeprecationWarning
)
class Data(list):
"""
plotly.graph_objs.Data is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.Scatter
- plotly.graph_objs.Bar
- plotly.graph_objs.Area
- plotly.graph_objs.Histogram
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Data is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.Scatter
- plotly.graph_objs.Bar
- plotly.graph_objs.Area
- plotly.graph_objs.Histogram
- etc.
"""
warnings.warn(
"""plotly.graph_objs.Data is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.Scatter
- plotly.graph_objs.Bar
- plotly.graph_objs.Area
- plotly.graph_objs.Histogram
- etc.
""",
DeprecationWarning,
)
super(Data, self).__init__(*args, **kwargs)
class Annotations(list):
"""
plotly.graph_objs.Annotations is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.layout.Annotation
- plotly.graph_objs.layout.scene.Annotation
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Annotations is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.layout.Annotation
- plotly.graph_objs.layout.scene.Annotation
"""
warnings.warn(
"""plotly.graph_objs.Annotations is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.layout.Annotation
- plotly.graph_objs.layout.scene.Annotation
""",
DeprecationWarning,
)
super(Annotations, self).__init__(*args, **kwargs)
class Frames(list):
"""
plotly.graph_objs.Frames is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.Frame
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Frames is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.Frame
"""
warnings.warn(
"""plotly.graph_objs.Frames is deprecated.
Please replace it with a list or tuple of instances of the following types
- plotly.graph_objs.Frame
""",
DeprecationWarning,
)
super(Frames, self).__init__(*args, **kwargs)
class AngularAxis(dict):
"""
plotly.graph_objs.AngularAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.AngularAxis
- plotly.graph_objs.layout.polar.AngularAxis
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.AngularAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.AngularAxis
- plotly.graph_objs.layout.polar.AngularAxis
"""
warnings.warn(
"""plotly.graph_objs.AngularAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.AngularAxis
- plotly.graph_objs.layout.polar.AngularAxis
""",
DeprecationWarning,
)
super(AngularAxis, self).__init__(*args, **kwargs)
class Annotation(dict):
"""
plotly.graph_objs.Annotation is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Annotation
- plotly.graph_objs.layout.scene.Annotation
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Annotation is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Annotation
- plotly.graph_objs.layout.scene.Annotation
"""
warnings.warn(
"""plotly.graph_objs.Annotation is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Annotation
- plotly.graph_objs.layout.scene.Annotation
""",
DeprecationWarning,
)
super(Annotation, self).__init__(*args, **kwargs)
class ColorBar(dict):
"""
plotly.graph_objs.ColorBar is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.marker.ColorBar
- plotly.graph_objs.surface.ColorBar
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.ColorBar is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.marker.ColorBar
- plotly.graph_objs.surface.ColorBar
- etc.
"""
warnings.warn(
"""plotly.graph_objs.ColorBar is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.marker.ColorBar
- plotly.graph_objs.surface.ColorBar
- etc.
""",
DeprecationWarning,
)
super(ColorBar, self).__init__(*args, **kwargs)
class Contours(dict):
"""
plotly.graph_objs.Contours is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.contour.Contours
- plotly.graph_objs.surface.Contours
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Contours is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.contour.Contours
- plotly.graph_objs.surface.Contours
- etc.
"""
warnings.warn(
"""plotly.graph_objs.Contours is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.contour.Contours
- plotly.graph_objs.surface.Contours
- etc.
""",
DeprecationWarning,
)
super(Contours, self).__init__(*args, **kwargs)
class ErrorX(dict):
"""
plotly.graph_objs.ErrorX is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.ErrorX
- plotly.graph_objs.histogram.ErrorX
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.ErrorX is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.ErrorX
- plotly.graph_objs.histogram.ErrorX
- etc.
"""
warnings.warn(
"""plotly.graph_objs.ErrorX is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.ErrorX
- plotly.graph_objs.histogram.ErrorX
- etc.
""",
DeprecationWarning,
)
super(ErrorX, self).__init__(*args, **kwargs)
class ErrorY(dict):
"""
plotly.graph_objs.ErrorY is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.ErrorY
- plotly.graph_objs.histogram.ErrorY
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.ErrorY is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.ErrorY
- plotly.graph_objs.histogram.ErrorY
- etc.
"""
warnings.warn(
"""plotly.graph_objs.ErrorY is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.ErrorY
- plotly.graph_objs.histogram.ErrorY
- etc.
""",
DeprecationWarning,
)
super(ErrorY, self).__init__(*args, **kwargs)
class ErrorZ(dict):
"""
plotly.graph_objs.ErrorZ is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter3d.ErrorZ
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.ErrorZ is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter3d.ErrorZ
"""
warnings.warn(
"""plotly.graph_objs.ErrorZ is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter3d.ErrorZ
""",
DeprecationWarning,
)
super(ErrorZ, self).__init__(*args, **kwargs)
class Font(dict):
"""
plotly.graph_objs.Font is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Font
- plotly.graph_objs.layout.hoverlabel.Font
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Font is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Font
- plotly.graph_objs.layout.hoverlabel.Font
- etc.
"""
warnings.warn(
"""plotly.graph_objs.Font is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Font
- plotly.graph_objs.layout.hoverlabel.Font
- etc.
""",
DeprecationWarning,
)
super(Font, self).__init__(*args, **kwargs)
class Legend(dict):
"""
plotly.graph_objs.Legend is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Legend
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Legend is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Legend
"""
warnings.warn(
"""plotly.graph_objs.Legend is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Legend
""",
DeprecationWarning,
)
super(Legend, self).__init__(*args, **kwargs)
class Line(dict):
"""
plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Line
- plotly.graph_objs.layout.shape.Line
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Line
- plotly.graph_objs.layout.shape.Line
- etc.
"""
warnings.warn(
"""plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Line
- plotly.graph_objs.layout.shape.Line
- etc.
""",
DeprecationWarning,
)
super(Line, self).__init__(*args, **kwargs)
class Margin(dict):
"""
plotly.graph_objs.Margin is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Margin
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Margin is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Margin
"""
warnings.warn(
"""plotly.graph_objs.Margin is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Margin
""",
DeprecationWarning,
)
super(Margin, self).__init__(*args, **kwargs)
class Marker(dict):
"""
plotly.graph_objs.Marker is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Marker
- plotly.graph_objs.histogram.selected.Marker
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Marker is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Marker
- plotly.graph_objs.histogram.selected.Marker
- etc.
"""
warnings.warn(
"""plotly.graph_objs.Marker is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Marker
- plotly.graph_objs.histogram.selected.Marker
- etc.
""",
DeprecationWarning,
)
super(Marker, self).__init__(*args, **kwargs)
class RadialAxis(dict):
"""
plotly.graph_objs.RadialAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.RadialAxis
- plotly.graph_objs.layout.polar.RadialAxis
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.RadialAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.RadialAxis
- plotly.graph_objs.layout.polar.RadialAxis
"""
warnings.warn(
"""plotly.graph_objs.RadialAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.RadialAxis
- plotly.graph_objs.layout.polar.RadialAxis
""",
DeprecationWarning,
)
super(RadialAxis, self).__init__(*args, **kwargs)
class Scene(dict):
"""
plotly.graph_objs.Scene is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Scene
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Scene is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Scene
"""
warnings.warn(
"""plotly.graph_objs.Scene is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.Scene
""",
DeprecationWarning,
)
super(Scene, self).__init__(*args, **kwargs)
class Stream(dict):
"""
plotly.graph_objs.Stream is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Stream
- plotly.graph_objs.area.Stream
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Stream is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Stream
- plotly.graph_objs.area.Stream
"""
warnings.warn(
"""plotly.graph_objs.Stream is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.scatter.Stream
- plotly.graph_objs.area.Stream
""",
DeprecationWarning,
)
super(Stream, self).__init__(*args, **kwargs)
class XAxis(dict):
"""
plotly.graph_objs.XAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.XAxis
- plotly.graph_objs.layout.scene.XAxis
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.XAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.XAxis
- plotly.graph_objs.layout.scene.XAxis
"""
warnings.warn(
"""plotly.graph_objs.XAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.XAxis
- plotly.graph_objs.layout.scene.XAxis
""",
DeprecationWarning,
)
super(XAxis, self).__init__(*args, **kwargs)
class YAxis(dict):
"""
plotly.graph_objs.YAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.YAxis
- plotly.graph_objs.layout.scene.YAxis
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.YAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.YAxis
- plotly.graph_objs.layout.scene.YAxis
"""
warnings.warn(
"""plotly.graph_objs.YAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.YAxis
- plotly.graph_objs.layout.scene.YAxis
""",
DeprecationWarning,
)
super(YAxis, self).__init__(*args, **kwargs)
class ZAxis(dict):
"""
plotly.graph_objs.ZAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.scene.ZAxis
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.ZAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.scene.ZAxis
"""
warnings.warn(
"""plotly.graph_objs.ZAxis is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.layout.scene.ZAxis
""",
DeprecationWarning,
)
super(ZAxis, self).__init__(*args, **kwargs)
class XBins(dict):
"""
plotly.graph_objs.XBins is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.histogram.XBins
- plotly.graph_objs.histogram2d.XBins
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.XBins is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.histogram.XBins
- plotly.graph_objs.histogram2d.XBins
"""
warnings.warn(
"""plotly.graph_objs.XBins is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.histogram.XBins
- plotly.graph_objs.histogram2d.XBins
""",
DeprecationWarning,
)
super(XBins, self).__init__(*args, **kwargs)
class YBins(dict):
"""
plotly.graph_objs.YBins is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.histogram.YBins
- plotly.graph_objs.histogram2d.YBins
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.YBins is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.histogram.YBins
- plotly.graph_objs.histogram2d.YBins
"""
warnings.warn(
"""plotly.graph_objs.YBins is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.histogram.YBins
- plotly.graph_objs.histogram2d.YBins
""",
DeprecationWarning,
)
super(YBins, self).__init__(*args, **kwargs)
class Trace(dict):
"""
plotly.graph_objs.Trace is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.Scatter
- plotly.graph_objs.Bar
- plotly.graph_objs.Area
- plotly.graph_objs.Histogram
- etc.
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Trace is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.Scatter
- plotly.graph_objs.Bar
- plotly.graph_objs.Area
- plotly.graph_objs.Histogram
- etc.
"""
warnings.warn(
"""plotly.graph_objs.Trace is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.Scatter
- plotly.graph_objs.Bar
- plotly.graph_objs.Area
- plotly.graph_objs.Histogram
- etc.
""",
DeprecationWarning,
)
super(Trace, self).__init__(*args, **kwargs)
class Histogram2dcontour(dict):
"""
plotly.graph_objs.Histogram2dcontour is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.Histogram2dContour
"""
def __init__(self, *args, **kwargs):
"""
plotly.graph_objs.Histogram2dcontour is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.Histogram2dContour
"""
warnings.warn(
"""plotly.graph_objs.Histogram2dcontour is deprecated.
Please replace it with one of the following more specific types
- plotly.graph_objs.Histogram2dContour
""",
DeprecationWarning,
)
super(Histogram2dcontour, self).__init__(*args, **kwargs)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,267 @@
from plotly.basedatatypes import BaseFrameHierarchyType as _BaseFrameHierarchyType
import copy as _copy
class Frame(_BaseFrameHierarchyType):
# class properties
# --------------------
_parent_path_str = ""
_path_str = "frame"
_valid_props = {"baseframe", "data", "group", "layout", "name", "traces"}
# baseframe
# ---------
@property
def baseframe(self):
"""
The name of the frame into which this frame's properties are
merged before applying. This is used to unify properties and
avoid needing to specify the same values for the same
properties in multiple frames.
The 'baseframe' property is a string and must be specified as:
- A string
- A number that will be converted to a string
Returns
-------
str
"""
return self["baseframe"]
@baseframe.setter
def baseframe(self, val):
self["baseframe"] = val
# data
# ----
@property
def data(self):
"""
A list of traces this frame modifies. The format is identical
to the normal trace definition.
Returns
-------
Any
"""
return self["data"]
@data.setter
def data(self, val):
self["data"] = val
# group
# -----
@property
def group(self):
"""
An identifier that specifies the group to which the frame
belongs, used by animate to select a subset of frames.
The 'group' property is a string and must be specified as:
- A string
- A number that will be converted to a string
Returns
-------
str
"""
return self["group"]
@group.setter
def group(self, val):
self["group"] = val
# layout
# ------
@property
def layout(self):
"""
Layout properties which this frame modifies. The format is
identical to the normal layout definition.
Returns
-------
Any
"""
return self["layout"]
@layout.setter
def layout(self, val):
self["layout"] = val
# name
# ----
@property
def name(self):
"""
A label by which to identify the frame
The 'name' property is a string and must be specified as:
- A string
- A number that will be converted to a string
Returns
-------
str
"""
return self["name"]
@name.setter
def name(self, val):
self["name"] = val
# traces
# ------
@property
def traces(self):
"""
A list of trace indices that identify the respective traces in
the data attribute
The 'traces' property accepts values of any type
Returns
-------
Any
"""
return self["traces"]
@traces.setter
def traces(self, val):
self["traces"] = val
# Self properties description
# ---------------------------
@property
def _prop_descriptions(self):
return """\
baseframe
The name of the frame into which this frame's
properties are merged before applying. This is used to
unify properties and avoid needing to specify the same
values for the same properties in multiple frames.
data
A list of traces this frame modifies. The format is
identical to the normal trace definition.
group
An identifier that specifies the group to which the
frame belongs, used by animate to select a subset of
frames.
layout
Layout properties which this frame modifies. The format
is identical to the normal layout definition.
name
A label by which to identify the frame
traces
A list of trace indices that identify the respective
traces in the data attribute
"""
def __init__(
self,
arg=None,
baseframe=None,
data=None,
group=None,
layout=None,
name=None,
traces=None,
**kwargs,
):
"""
Construct a new Frame object
Parameters
----------
arg
dict of properties compatible with this constructor or
an instance of :class:`plotly.graph_objs.Frame`
baseframe
The name of the frame into which this frame's
properties are merged before applying. This is used to
unify properties and avoid needing to specify the same
values for the same properties in multiple frames.
data
A list of traces this frame modifies. The format is
identical to the normal trace definition.
group
An identifier that specifies the group to which the
frame belongs, used by animate to select a subset of
frames.
layout
Layout properties which this frame modifies. The format
is identical to the normal layout definition.
name
A label by which to identify the frame
traces
A list of trace indices that identify the respective
traces in the data attribute
Returns
-------
Frame
"""
super(Frame, self).__init__("frames")
if "_parent" in kwargs:
self._parent = kwargs["_parent"]
return
# Validate arg
# ------------
if arg is None:
arg = {}
elif isinstance(arg, self.__class__):
arg = arg.to_plotly_json()
elif isinstance(arg, dict):
arg = _copy.copy(arg)
else:
raise ValueError(
"""\
The first argument to the plotly.graph_objs.Frame
constructor must be a dict or
an instance of :class:`plotly.graph_objs.Frame`"""
)
# Handle skip_invalid
# -------------------
self._skip_invalid = kwargs.pop("skip_invalid", False)
self._validate = kwargs.pop("_validate", True)
# Populate data dict with properties
# ----------------------------------
_v = arg.pop("baseframe", None)
_v = baseframe if baseframe is not None else _v
if _v is not None:
self["baseframe"] = _v
_v = arg.pop("data", None)
_v = data if data is not None else _v
if _v is not None:
self["data"] = _v
_v = arg.pop("group", None)
_v = group if group is not None else _v
if _v is not None:
self["group"] = _v
_v = arg.pop("layout", None)
_v = layout if layout is not None else _v
if _v is not None:
self["layout"] = _v
_v = arg.pop("name", None)
_v = name if name is not None else _v
if _v is not None:
self["name"] = _v
_v = arg.pop("traces", None)
_v = traces if traces is not None else _v
if _v is not None:
self["traces"] = _v
# Process unknown kwargs
# ----------------------
self._process_kwargs(**dict(arg, **kwargs))
# Reset skip_invalid
# ------------------
self._skip_invalid = False
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More