first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions

View File

@ -0,0 +1,112 @@
"""
`plotly.express` is a terse, consistent, high-level wrapper around `plotly.graph_objects`
for rapid data exploration and figure generation. Learn more at https://plotly.express/
"""
from __future__ import absolute_import
from plotly import optional_imports
pd = optional_imports.get_module("pandas")
if pd is None:
raise ImportError(
"""\
Plotly express requires pandas to be installed."""
)
from ._imshow import imshow
from ._chart_types import ( # noqa: F401
scatter,
scatter_3d,
scatter_polar,
scatter_ternary,
scatter_mapbox,
scatter_geo,
line,
line_3d,
line_polar,
line_ternary,
line_mapbox,
line_geo,
area,
bar,
timeline,
bar_polar,
violin,
box,
strip,
histogram,
ecdf,
scatter_matrix,
parallel_coordinates,
parallel_categories,
choropleth,
density_contour,
density_heatmap,
pie,
sunburst,
treemap,
icicle,
funnel,
funnel_area,
choropleth_mapbox,
density_mapbox,
)
from ._core import ( # noqa: F401
set_mapbox_access_token,
defaults,
get_trendline_results,
NO_COLOR,
)
from ._special_inputs import IdentityMap, Constant, Range # noqa: F401
from . import data, colors, trendline_functions # noqa: F401
# Public names exposed by `from plotly.express import *`.
# NOTE(review): `defaults` is imported from `._core` above but is not listed
# here -- confirm whether that omission is intentional.
__all__ = [
"scatter",
"scatter_3d",
"scatter_polar",
"scatter_ternary",
"scatter_mapbox",
"scatter_geo",
"scatter_matrix",
"density_contour",
"density_heatmap",
"density_mapbox",
"line",
"line_3d",
"line_polar",
"line_ternary",
"line_mapbox",
"line_geo",
"parallel_coordinates",
"parallel_categories",
"area",
"bar",
"timeline",
"bar_polar",
"violin",
"box",
"strip",
"histogram",
"ecdf",
"choropleth",
"choropleth_mapbox",
"pie",
"sunburst",
"treemap",
"icicle",
"funnel",
"funnel_area",
"imshow",
"data",
"colors",
"trendline_functions",
"set_mapbox_access_token",
"get_trendline_results",
"IdentityMap",
"Constant",
"Range",
"NO_COLOR",
]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,625 @@
import inspect
from textwrap import TextWrapper
# Use inspect.getfullargspec when the interpreter provides it; otherwise fall
# back to the older inspect.getargspec name.
getfullargspec = getattr(inspect, "getfullargspec", None)
if getfullargspec is None:  # python 2
    getfullargspec = inspect.getargspec
# Shared type / description strings reused by the entries of `docs` below
# that accept a single column reference.
colref_type = "str or int or Series or array-like"
colref_desc = "Either a name of a column in `data_frame`, or a pandas Series or array_like object."
# Same idea for the entries of `docs` that accept a list of column references.
colref_list_type = "list of str or int, or Series or array-like"
colref_list_desc = (
"Either names of columns in `data_frame`, or pandas Series, or array_like objects"
)
docs = dict(
# Layout of every entry: [type string, description sentence, ...].
data_frame=[
    "DataFrame or array-like or dict",
    "This argument needs to be passed for column names (and not keyword names) to be used.",
    "Array-like and dict are transformed internally to a pandas DataFrame.",
    "Optional: if missing, a DataFrame gets constructed under the hood using the other arguments.",
],
x=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.",
],
y=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the y axis in cartesian coordinates.",
],
z=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the z axis in cartesian coordinates.",
],
x_start=[
colref_type,
colref_desc,
"(required)",
"Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.",
],
x_end=[
colref_type,
colref_desc,
"(required)",
"Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.",
],
a=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the a axis in ternary coordinates.",
],
b=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the b axis in ternary coordinates.",
],
c=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the c axis in ternary coordinates.",
],
r=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the radial axis in polar coordinates.",
],
theta=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks along the angular axis in polar coordinates.",
],
values=[
colref_type,
colref_desc,
"Values from this column or array_like are used to set values associated to sectors.",
],
parents=[
colref_type,
colref_desc,
"Values from this column or array_like are used as parents in sunburst and treemap charts.",
],
ids=[
colref_type,
colref_desc,
"Values from this column or array_like are used to set ids of sectors",
],
path=[
colref_list_type,
colref_list_desc,
"List of columns names or columns of a rectangular dataframe defining the hierarchy of sectors, from root to leaves.",
"An error is raised if path AND ids or parents is passed",
],
lat=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks according to latitude on a map.",
],
lon=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position marks according to longitude on a map.",
],
locations=[
colref_type,
colref_desc,
"Values from this column or array_like are to be interpreted according to `locationmode` and mapped to longitude/latitude.",
],
base=[
colref_type,
colref_desc,
"Values from this column or array_like are used to position the base of the bar.",
],
dimensions=[
colref_list_type,
colref_list_desc,
"Values from these columns are used for multidimensional visualization.",
],
dimensions_max_cardinality=[
"int (default 50)",
"When `dimensions` is `None` and `data_frame` is provided, "
"columns with more than this number of unique values are excluded from the output.",
"Not used when `dimensions` is passed.",
],
error_x=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size x-axis error bars.",
"If `error_x_minus` is `None`, error bars will be symmetrical, otherwise `error_x` is used for the positive direction only.",
],
error_x_minus=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size x-axis error bars in the negative direction.",
"Ignored if `error_x` is `None`.",
],
error_y=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size y-axis error bars.",
"If `error_y_minus` is `None`, error bars will be symmetrical, otherwise `error_y` is used for the positive direction only.",
],
error_y_minus=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size y-axis error bars in the negative direction.",
"Ignored if `error_y` is `None`.",
],
error_z=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size z-axis error bars.",
"If `error_z_minus` is `None`, error bars will be symmetrical, otherwise `error_z` is used for the positive direction only.",
],
error_z_minus=[
colref_type,
colref_desc,
"Values from this column or array_like are used to size z-axis error bars in the negative direction.",
"Ignored if `error_z` is `None`.",
],
color=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign color to marks.",
],
opacity=["float", "Value between 0 and 1. Sets the opacity for markers."],
line_dash=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign dash-patterns to lines.",
],
line_group=[
colref_type,
colref_desc,
"Values from this column or array_like are used to group rows of `data_frame` into lines.",
],
symbol=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign symbols to marks.",
],
pattern_shape=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign pattern shapes to marks.",
],
size=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign mark sizes.",
],
radius=["int (default is 30)", "Sets the radius of influence of each point."],
hover_name=[
colref_type,
colref_desc,
"Values from this column or array_like appear in bold in the hover tooltip.",
],
hover_data=[
"list of str or int, or Series or array-like, or dict",
"Either a list of names of columns in `data_frame`, or pandas Series,",
"or array_like objects",
"or a dict with column names as keys, with values True (for default formatting)",
"False (in order to remove this column from hover information),",
"or a formatting string, for example ':.3f' or '|%a'",
"or list-like data to appear in the hover tooltip",
"or tuples with a bool or formatting string as first element,",
"and list-like data to appear in hover as second element",
"Values from these columns appear as extra data in the hover tooltip.",
],
custom_data=[
colref_list_type,
colref_list_desc,
"Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)",
],
text=[
colref_type,
colref_desc,
"Values from this column or array_like appear in the figure as text labels.",
],
names=[
colref_type,
colref_desc,
"Values from this column or array_like are used as labels for sectors.",
],
locationmode=[
"str",
"One of 'ISO-3', 'USA-states', or 'country names'",
"Determines the set of locations used to match entries in `locations` to regions on the map.",
],
facet_row=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign marks to facetted subplots in the vertical direction.",
],
facet_col=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.",
],
facet_col_wrap=[
"int",
"Maximum number of facet columns.",
"Wraps the column variable at this width, so that the column facets span multiple rows.",
"Ignored if 0, and forced to 0 if `facet_row` or a `marginal` is set.",
],
# [type, description] for the vertical gap between facet rows.
facet_row_spacing=[
    "float between 0 and 1",
    "Spacing between facet rows, in paper units. Default is 0.03 or 0.07 when facet_col_wrap is used.",
],
# [type, description] for the horizontal gap between facet columns.
facet_col_spacing=[
    "float between 0 and 1",
    "Spacing between facet columns, in paper units. Default is 0.02.",
],
animation_frame=[
colref_type,
colref_desc,
"Values from this column or array_like are used to assign marks to animation frames.",
],
animation_group=[
colref_type,
colref_desc,
"Values from this column or array_like are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.",
],
symbol_sequence=[
"list of str",
"Strings should define valid plotly.js symbols.",
"When `symbol` is set, values in that column are assigned symbols by cycling through `symbol_sequence` in the order described in `category_orders`, unless the value of `symbol` is a key in `symbol_map`.",
],
symbol_map=[
"dict with str keys and str values (default `{}`)",
"String values should define plotly.js symbols",
"Used to override `symbol_sequence` to assign a specific symbols to marks corresponding with specific values.",
"Keys in `symbol_map` should be values in the column denoted by `symbol`.",
"Alternatively, if the values of `symbol` are valid symbol names, the string `'identity'` may be passed to cause them to be used directly.",
],
# [type, description sentences...]; the sequence argument it overrides is
# documented in this dict under the key `line_dash_sequence` (singular).
line_dash_map=[
    "dict with str keys and str values (default `{}`)",
    "Strings values define plotly.js dash-patterns.",
    "Used to override `line_dash_sequence` to assign specific dash-patterns to lines corresponding with specific values.",
    "Keys in `line_dash_map` should be values in the column denoted by `line_dash`.",
    "Alternatively, if the values of `line_dash` are valid line-dash names, the string `'identity'` may be passed to cause them to be used directly.",
],
line_dash_sequence=[
"list of str",
"Strings should define valid plotly.js dash-patterns.",
"When `line_dash` is set, values in that column are assigned dash-patterns by cycling through `line_dash_sequence` in the order described in `category_orders`, unless the value of `line_dash` is a key in `line_dash_map`.",
],
# [type, description sentences...]; the sequence argument it overrides is
# documented in this dict under the key `pattern_shape_sequence` (singular),
# and the `pattern_shape` entry describes pattern shapes as applying to marks.
pattern_shape_map=[
    "dict with str keys and str values (default `{}`)",
    "Strings values define plotly.js patterns-shapes.",
    "Used to override `pattern_shape_sequence` to assign specific patterns-shapes to marks corresponding with specific values.",
    "Keys in `pattern_shape_map` should be values in the column denoted by `pattern_shape`.",
    "Alternatively, if the values of `pattern_shape` are valid patterns-shapes names, the string `'identity'` may be passed to cause them to be used directly.",
],
pattern_shape_sequence=[
"list of str",
"Strings should define valid plotly.js patterns-shapes.",
"When `pattern_shape` is set, values in that column are assigned patterns-shapes by cycling through `pattern_shape_sequence` in the order described in `category_orders`, unless the value of `pattern_shape` is a key in `pattern_shape_map`.",
],
color_discrete_sequence=[
"list of str",
"Strings should define valid CSS-colors.",
"When `color` is set and the values in the corresponding column are not numeric, values in that column are assigned colors by cycling through `color_discrete_sequence` in the order described in `category_orders`, unless the value of `color` is a key in `color_discrete_map`.",
"Various useful color sequences are available in the `plotly.express.colors` submodules, specifically `plotly.express.colors.qualitative`.",
],
color_discrete_map=[
"dict with str keys and str values (default `{}`)",
"String values should define valid CSS-colors",
"Used to override `color_discrete_sequence` to assign a specific colors to marks corresponding with specific values.",
"Keys in `color_discrete_map` should be values in the column denoted by `color`.",
"Alternatively, if the values of `color` are valid colors, the string `'identity'` may be passed to cause them to be used directly.",
],
color_continuous_scale=[
"list of str",
"Strings should define valid CSS-colors",
"This list is used to build a continuous color scale when the column denoted by `color` contains numeric data.",
"Various useful color scales are available in the `plotly.express.colors` submodules, specifically `plotly.express.colors.sequential`, `plotly.express.colors.diverging` and `plotly.express.colors.cyclical`.",
],
color_continuous_midpoint=[
"number (default `None`)",
"If set, computes the bounds of the continuous color scale to have the desired midpoint.",
"Setting this value is recommended when using `plotly.express.colors.diverging` color scales as the inputs to `color_continuous_scale`.",
],
size_max=["int (default `20`)", "Set the maximum mark size when using `size`."],
markers=["boolean (default `False`)", "If `True`, markers are shown on lines."],
lines=[
"boolean (default `True`)",
"If `False`, lines are not drawn (forced to `True` if `markers` is `False`).",
],
log_x=[
"boolean (default `False`)",
"If `True`, the x-axis is log-scaled in cartesian coordinates.",
],
log_y=[
"boolean (default `False`)",
"If `True`, the y-axis is log-scaled in cartesian coordinates.",
],
log_z=[
"boolean (default `False`)",
"If `True`, the z-axis is log-scaled in cartesian coordinates.",
],
log_r=[
"boolean (default `False`)",
"If `True`, the radial axis is log-scaled in polar coordinates.",
],
range_x=[
"list of two numbers",
"If provided, overrides auto-scaling on the x-axis in cartesian coordinates.",
],
range_y=[
"list of two numbers",
"If provided, overrides auto-scaling on the y-axis in cartesian coordinates.",
],
range_z=[
"list of two numbers",
"If provided, overrides auto-scaling on the z-axis in cartesian coordinates.",
],
range_color=[
"list of two numbers",
"If provided, overrides auto-scaling on the continuous color scale.",
],
range_r=[
"list of two numbers",
"If provided, overrides auto-scaling on the radial axis in polar coordinates.",
],
range_theta=[
"list of two numbers",
"If provided, overrides auto-scaling on the angular axis in polar coordinates.",
],
title=["str", "The figure title."],
template=[
"str or dict or plotly.graph_objects.layout.Template instance",
"The figure template name (must be a key in plotly.io.templates) or definition.",
],
width=["int (default `None`)", "The figure width in pixels."],
height=["int (default `None`)", "The figure height in pixels."],
labels=[
"dict with str keys and str values (default `{}`)",
"By default, column names are used in the figure for axis titles, legend entries and hovers.",
"This parameter allows this to be overridden.",
"The keys of this dict should correspond to column names, and the values should correspond to the desired label to be displayed.",
],
category_orders=[
"dict with str keys and list of str values (default `{}`)",
"By default, in Python 3.6+, the order of categorical values in axes, legends and facets depends on the order in which these values are first encountered in `data_frame` (and no order is guaranteed by default in Python below 3.6).",
"This parameter is used to force a specific ordering of values per column.",
"The keys of this dict should correspond to column names, and the values should be lists of strings corresponding to the specific display order desired.",
],
marginal=[
"str",
"One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.",
"If set, a subplot is drawn alongside the main plot, visualizing the distribution.",
],
marginal_x=[
"str",
"One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.",
"If set, a horizontal subplot is drawn above the main plot, visualizing the x-distribution.",
],
marginal_y=[
"str",
"One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.",
"If set, a vertical subplot is drawn to the right of the main plot, visualizing the y-distribution.",
],
# [type, description sentences...]; every option name is written as `'name'`
# (backtick, quote, name, quote, backtick).
trendline=[
    "str",
    "One of `'ols'`, `'lowess'`, `'rolling'`, `'expanding'` or `'ewm'`.",
    "If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.",
    "If `'lowess'`, a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.",
    "If `'rolling'`, a Rolling (e.g. rolling average, rolling median) line will be drawn for each discrete-color/symbol group.",
    "If `'expanding'`, an Expanding (e.g. expanding average, expanding sum) line will be drawn for each discrete-color/symbol group.",
    "If `'ewm'`, an Exponentially Weighted Moment (e.g. exponentially-weighted moving average) line will be drawn for each discrete-color/symbol group.",
    "See the docstrings for the functions in `plotly.express.trendline_functions` for more details on these functions and how",
    "to configure them with the `trendline_options` argument.",
],
trendline_options=[
"dict",
"Options passed as the first argument to the function from `plotly.express.trendline_functions` ",
"named in the `trendline` argument.",
],
trendline_color_override=[
"str",
"Valid CSS color.",
"If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.",
],
trendline_scope=[
"str (one of `'trace'` or `'overall'`, default `'trace'`)",
"If `'trace'`, then one trendline is drawn per trace (i.e. per color, symbol, facet, animation frame etc) and if `'overall'` then one trendline is computed for the entire dataset, and replicated across all facets.",
],
# [type, description sentences...] for the choice of browser drawing API.
render_mode=[
    "str",
    "One of `'auto'`, `'svg'` or `'webgl'`, default `'auto'`",
    "Controls the browser API used to draw marks.",
    "`'svg'` is appropriate for figures of less than 1000 data points, and will allow for fully-vectorized output.",
    "`'webgl'` is likely necessary for acceptable performance above 1000 points but rasterizes part of the output. ",
    "`'auto'` uses heuristics to choose the mode.",
],
# [type, description sentences...] for the angular-axis drawing direction.
direction=[
    "str",
    "One of `'counterclockwise'` or `'clockwise'`. Default is `'clockwise'`",
    "Sets the direction in which increasing values of the angular axis are drawn.",
],
start_angle=[
"int (default `90`)",
"Sets start angle for the angular axis, with 0 being due east and 90 being due north.",
],
# One sentence per list item: make_docstring joins the items with single
# spaces, whereas adjacent string literals without a comma would be fused
# together with no separator.
histfunc=[
    "str (default `'count'` if no arguments are provided, else `'sum'`)",
    "One of `'count'`, `'sum'`, `'avg'`, `'min'`, or `'max'`.",
    "Function used to aggregate values for summarization (note: can be normalized with `histnorm`).",
],
histnorm=[
"str (default `None`)",
"One of `'percent'`, `'probability'`, `'density'`, or `'probability density'`",
"If `None`, the output of `histfunc` is used as is.",
"If `'probability'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins.",
"If `'percent'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins and multiplied by 100.",
"If `'density'`, the output of `histfunc` for a given bin is divided by the size of the bin.",
"If `'probability density'`, the output of `histfunc` for a given bin is normalized such that it corresponds to the probability that a random event whose distribution is described by the output of `histfunc` will fall into that bin.",
],
barnorm=[
"str (default `None`)",
"One of `'fraction'` or `'percent'`.",
"If `'fraction'`, the value of each bar is divided by the sum of all values at that location coordinate.",
"`'percent'` is the same but multiplied by 100 to show percentages.",
"`None` will stack up all values at each location coordinate.",
],
groupnorm=[
"str (default `None`)",
"One of `'fraction'` or `'percent'`.",
"If `'fraction'`, the value of each point is divided by the sum of all values at that location coordinate.",
"`'percent'` is the same but multiplied by 100 to show percentages.",
"`None` will stack up all values at each location coordinate.",
],
barmode=[
"str (default `'relative'`)",
"One of `'group'`, `'overlay'` or `'relative'`",
"In `'relative'` mode, bars are stacked above zero for positive values and below zero for negative values.",
"In `'overlay'` mode, bars are drawn on top of one another.",
"In `'group'` mode, bars are placed beside each other.",
],
# [type, description sentences...] for how boxes sharing a position are laid out.
boxmode=[
    "str (default `'group'`)",
    "One of `'group'` or `'overlay'`",
    "In `'overlay'` mode, boxes are drawn on top of one another.",
    "In `'group'` mode, boxes are placed beside each other.",
],
# [type, description sentences...] for how violins sharing a position are laid out.
violinmode=[
    "str (default `'group'`)",
    "One of `'group'` or `'overlay'`",
    "In `'overlay'` mode, violins are drawn on top of one another.",
    "In `'group'` mode, violins are placed beside each other.",
],
# [type, description sentences...] for how strips sharing a position are laid out.
stripmode=[
    "str (default `'group'`)",
    "One of `'group'` or `'overlay'`",
    "In `'overlay'` mode, strips are drawn on top of one another.",
    "In `'group'` mode, strips are placed beside each other.",
],
zoom=["int (default `8`)", "Between 0 and 20.", "Sets map zoom level."],
# The first item is the type; the remaining items together form one
# parenthesised sentence describing how the default is chosen.
orientation=[
    "str, one of `'h'` for horizontal or `'v'` for vertical. ",
    "(default `'v'` if `x` and `y` are provided and both continuous or both categorical, ",
    "otherwise `'v'`(`'h'`) if `x`(`y`) is categorical and `y`(`x`) is continuous, ",
    "otherwise `'v'`(`'h'`) if only `x`(`y`) is provided) ",
],
line_close=[
"boolean (default `False`)",
"If `True`, an extra line segment is drawn between the first and last point.",
],
line_shape=["str (default `'linear'`)", "One of `'linear'` or `'spline'`."],
fitbounds=["str (default `False`).", "One of `False`, `locations` or `geojson`."],
basemap_visible=["bool", "Force the basemap visibility."],
# One sentence per list item: make_docstring joins the items with single
# spaces, whereas adjacent string literals without a comma would be fused
# together with no separator.
scope=[
    "str (default `'world'`).",
    "One of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, or `'south america'`",
    "Default is `'world'` unless `projection` is set to `'albers usa'`, which forces `'usa'`.",
],
# One sentence per list item: make_docstring joins the items with single
# spaces, whereas adjacent string literals without a comma would be fused
# together with no separator.
projection=[
    "str ",
    "One of `'equirectangular'`, `'mercator'`, `'orthographic'`, `'natural earth'`, `'kavrayskiy7'`, `'miller'`, `'robinson'`, `'eckert4'`, `'azimuthal equal area'`, `'azimuthal equidistant'`, `'conic equal area'`, `'conic conformal'`, `'conic equidistant'`, `'gnomonic'`, `'stereographic'`, `'mollweide'`, `'hammer'`, `'transverse mercator'`, `'albers usa'`, `'winkel tripel'`, `'aitoff'`, or `'sinusoidal'`",
    "Default depends on `scope`.",
],
center=[
"dict",
"Dict keys are `'lat'` and `'lon'`",
"Sets the center point of the map.",
],
mapbox_style=[
"str (default `'basic'`, needs Mapbox API token)",
"Identifier of base map style, some of which require a Mapbox API token to be set using `plotly.express.set_mapbox_access_token()`.",
"Allowed values which do not require a Mapbox API token are `'open-street-map'`, `'white-bg'`, `'carto-positron'`, `'carto-darkmatter'`, `'stamen-terrain'`, `'stamen-toner'`, `'stamen-watercolor'`.",
"Allowed values which do require a Mapbox API token are `'basic'`, `'streets'`, `'outdoors'`, `'light'`, `'dark'`, `'satellite'`, `'satellite-streets'`.",
],
# [type, description sentences...]; one sentence per accepted value.
points=[
    "str or boolean (default `'outliers'`)",
    "One of `'outliers'`, `'suspectedoutliers'`, `'all'`, or `False`.",
    "If `'outliers'`, only the sample points lying outside the whiskers are shown.",
    "If `'suspectedoutliers'`, all outlier points are shown and those less than 4*Q1-3*Q3 or greater than 4*Q3-3*Q1 are highlighted with the marker's `'outliercolor'`.",
    "If `'all'`, all sample points are shown.",
    "If `False`, no sample points are shown and the whiskers extend to the full range of the sample.",
],
box=["boolean (default `False`)", "If `True`, boxes are drawn inside the violins."],
notched=["boolean (default `False`)", "If `True`, boxes are drawn with notches."],
geojson=[
"GeoJSON-formatted dict",
"Must contain a Polygon feature collection, with IDs, which are references from `locations`.",
],
# One sentence per list item: make_docstring joins the items with single
# spaces, whereas adjacent string literals without a comma would be fused
# together with no separator.
featureidkey=[
    "str (default: `'id'`)",
    "Path to field in GeoJSON feature object with which to match the values passed in to `locations`.",
    "The most common alternative to the default is of the form `'properties.<key>'`.",
],
cumulative=[
"boolean (default `False`)",
"If `True`, histogram values are cumulative.",
],
nbins=["int", "Positive integer.", "Sets the number of bins."],
nbinsx=["int", "Positive integer.", "Sets the number of bins along the x axis."],
nbinsy=["int", "Positive integer.", "Sets the number of bins along the y axis."],
# One sentence per list item: make_docstring joins the items with single
# spaces, whereas adjacent string literals without a comma would be fused
# together with no separator (e.g. "When" + "set" -> "Whenset").
branchvalues=[
    "str",
    "'total' or 'remainder'",
    "Determines how the items in `values` are summed.",
    "When set to 'total', items in `values` are taken to be value of all its descendants.",
    "When set to 'remainder', items in `values` corresponding to the root and the branches sectors are taken to be the extra part not part of the sum of the values at their leaves.",
],
# The description is a single list item so that no words are fused by
# adjacent string-literal concatenation.
maxdepth=[
    "int",
    "Positive integer",
    "Sets the number of rendered sectors from any given `level`. Set `maxdepth` to -1 to render all the levels in the hierarchy.",
],
# [type, description sentences...]; every option name is written as `'name'`
# (backtick, quote, name, quote, backtick).
ecdfnorm=[
    "string or `None` (default `'probability'`)",
    "One of `'probability'` or `'percent'`",
    "If `None`, values will be raw counts or sums.",
    "If `'probability'`, values will be probabilities normalized from 0 to 1.",
    "If `'percent'`, values will be percentages normalized from 0 to 100.",
],
ecdfmode=[
"string (default `'standard'`)",
"One of `'standard'`, `'complementary'` or `'reversed'`",
"If `'standard'`, the ECDF is plotted such that values represent data at or below the point.",
"If `'complementary'`, the CCDF is plotted such that values represent data above the point.",
"If `'reversed'`, a variant of the CCDF is plotted such that values represent data at or above the point.",
],
text_auto=[
"bool or string (default `False`)",
"If `True` or a string, the x or y or z values will be displayed as text, depending on the orientation",
"A string like `'.2f'` will be interpreted as a `texttemplate` numeric formatting directive.",
],
)
def make_docstring(fn, override_dict=None, append_dict=None):
    """
    Build a docstring for `fn` from its own docstring plus a generated
    "Parameters" section and a fixed "Returns" section.

    Parameters
    ----------
    fn : callable
        Function whose positional argument names (as reported by
        `getfullargspec(fn)[0]`) are documented, in order.
    override_dict : dict, optional
        Maps a parameter name to a `[type, sentence, ...]` list that replaces
        the entry from the module-level `docs` map. Falsy values are ignored.
    append_dict : dict, optional
        Maps a parameter name to extra sentences appended to the `docs`
        entry. Only consulted when the parameter is not overridden.

    Returns
    -------
    str
        The assembled docstring text.

    Notes
    -----
    A parameter found in neither `override_dict` nor `docs` is listed with
    the placeholder "(documentation missing from map)" instead of raising
    `KeyError`.
    """
    override_dict = {} if override_dict is None else override_dict
    append_dict = {} if append_dict is None else append_dict
    # NOTE(review): the indent strings are reproduced exactly as they appear
    # in this file -- confirm the intended indentation width.
    tw = TextWrapper(width=75, initial_indent=" ", subsequent_indent=" ")
    result = (fn.__doc__ or "") + "\nParameters\n----------\n"
    for param in getfullargspec(fn)[0]:
        if override_dict.get(param):
            param_doc = list(override_dict[param])
        elif param in docs:
            # Copy first so the `+=` below never mutates the shared entry.
            param_doc = list(docs[param])
            if append_dict.get(param):
                param_doc += append_dict[param]
        else:
            # No type is known for an undocumented parameter, so only its
            # name and the placeholder are emitted.
            result += "%s\n%s\n" % (param, "(documentation missing from map)")
            continue
        # First item is the type; the rest are sentences joined by spaces
        # and re-wrapped.
        param_type = param_doc[0]
        param_desc = tw.fill(" ".join(param_doc[1:]))
        result += "%s: %s\n%s\n" % (param, param_type, param_desc)
    result += "\nReturns\n-------\n"
    result += " plotly.graph_objects.Figure"
    return result

View File

@ -0,0 +1,600 @@
import plotly.graph_objs as go
from _plotly_utils.basevalidators import ColorscaleValidator
from ._core import apply_default_cascade, init_figure, configure_animation_controls
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
import pandas as pd
import numpy as np
import itertools
from plotly.utils import image_array_to_data_uri
try:
import xarray
xarray_imported = True
except ImportError:
xarray_imported = False
_float_types = []
def _vectorize_zvalue(z, mode="max"):
    """
    Expand a zmin/zmax specification into a four-element sequence.

    Parameters
    ----------
    z : scalar, sequence of length 1, 3 or 4, or None
        Value to expand.
    mode : str
        When "max" the fourth slot is padded with 255; for any other value it
        is padded with 0.

    Returns
    -------
    None when `z` is None; `z` itself (not copied) when it already has length
    4; otherwise a new list of three values followed by the padding value.

    Raises
    ------
    ValueError
        If `z` is a sequence whose length is not 1, 3 or 4.
    """
    if z is None:
        return z
    alpha = 255 if mode == "max" else 0
    if np.isscalar(z):
        return [z] * 3 + [alpha]
    n_values = len(z)
    if n_values == 1:
        return list(z) * 3 + [alpha]
    if n_values == 3:
        return list(z) + [alpha]
    if n_values == 4:
        return z
    # Name the argument that was actually being processed, so that a bad
    # lower bound is not reported as a bad `zmax`.
    arg_name = "zmax" if mode == "max" else "zmin"
    raise ValueError(
        "%s can be a scalar, or an iterable of length 1, 3 or 4. "
        "A value of %s was passed for %s." % (arg_name, str(z), arg_name)
    )
def _infer_zmax_from_type(img):
    """
    Guess an upper intensity bound for `img` from its dtype and contents.

    For a dtype found in `_integer_types`, the element at index 1 of its
    `_integer_ranges` entry is returned (presumably the dtype's maximum --
    confirm against `imshow_utils`). Otherwise the largest finite value of
    `img` is compared, with a 5% tolerance, against 1, 255 and 65535 in that
    order and the first bound that covers it is returned; 2**32 is returned
    when none does.
    """
    dtype_cls = img.dtype.type
    if dtype_cls in _integer_types:
        return _integer_ranges[dtype_cls][1]
    tolerance = 1.05
    finite_max = img[np.isfinite(img)].max()
    for ceiling in (1, 255, 65535):
        if finite_max <= ceiling * tolerance:
            return ceiling
    return 2**32
def imshow(
    img,
    zmin=None,
    zmax=None,
    origin=None,
    labels={},
    x=None,
    y=None,
    animation_frame=None,
    facet_col=None,
    facet_col_wrap=None,
    facet_col_spacing=None,
    facet_row_spacing=None,
    color_continuous_scale=None,
    color_continuous_midpoint=None,
    range_color=None,
    title=None,
    template=None,
    width=None,
    height=None,
    aspect=None,
    contrast_rescaling=None,
    binary_string=None,
    binary_backend="auto",
    binary_compression_level=4,
    binary_format="png",
    text_auto=False,
) -> go.Figure:
    """
    Display an image, i.e. data on a 2D regular raster.

    Parameters
    ----------

    img: array-like image, or xarray
        The image data. Supported array shapes are

        - (M, N): an image with scalar data. The data is visualized
          using a colormap.
        - (M, N, 3): an image with RGB values.
        - (M, N, 4): an image with RGBA values, i.e. including transparency.

    zmin, zmax : scalar or iterable, optional
        zmin and zmax define the scalar range that the colormap covers. By default,
        zmin and zmax correspond to the min and max values of the datatype for integer
        datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
        a multichannel image of floats, the max of the image is computed and zmax is the
        smallest power of 256 (1, 255, 65535) greater than this max value,
        with a 5% tolerance. For a single-channel image, the max of the image is used.
        Overridden by range_color.

    origin : str, 'upper' or 'lower' (default 'upper')
        position of the [0, 0] pixel of the image array, in the upper left or lower left
        corner. The convention 'upper' is typically used for matrices and images.

    labels : dict with str keys and str values (default `{}`)
        Sets names used in the figure for axis titles (keys ``x`` and ``y``),
        colorbar title and hoverlabel (key ``color``). The values should correspond
        to the desired label to be displayed. If ``img`` is an xarray, dimension
        names are used for axis titles, and long name for the colorbar title
        (unless overridden in ``labels``). Possible keys are: x, y, and color.
        The dict is copied before use, so the caller's object is not modified.

    x, y: list-like, optional
        x and y are used to label the axes of single-channel heatmap visualizations and
        their lengths must match the lengths of the second and first dimensions of the
        img argument. They are auto-populated if the input is an xarray.

    animation_frame: int or str, optional (default None)
        axis number along which the image array is sliced to create an animation plot.
        If `img` is an xarray, `animation_frame` can be the name of one of the dimensions.

    facet_col: int or str, optional (default None)
        axis number along which the image array is sliced to create a facetted plot.
        If `img` is an xarray, `facet_col` can be the name of one of the dimensions.

    facet_col_wrap: int
        Maximum number of facet columns. Wraps the column variable at this width,
        so that the column facets span multiple rows.
        Ignored if `facet_col` is None.

    facet_col_spacing: float between 0 and 1
        Spacing between facet columns, in paper units. Default is 0.02.

    facet_row_spacing: float between 0 and 1
        Spacing between facet rows created when ``facet_col_wrap`` is used, in
        paper units. Default is 0.07.

    color_continuous_scale : str or list of str
        colormap used to map scalar data to colors (for a 2D image). This parameter is
        not used for RGB or RGBA images. If a string is provided, it should be the name
        of a known color scale, and if a list is provided, it should be a list of CSS-
        compatible colors.

    color_continuous_midpoint : number
        If set, computes the bounds of the continuous color scale to have the desired
        midpoint. Overridden by range_color or zmin and zmax.

    range_color : list of two numbers
        If provided, overrides auto-scaling on the continuous color scale, including
        overriding `color_continuous_midpoint`. Also overrides zmin and zmax. Used only
        for single-channel images.

    title : str
        The figure title.

    template : str or dict or plotly.graph_objects.layout.Template instance
        The figure template name or definition.

    width : number
        The figure width in pixels.

    height: number
        The figure height in pixels.

    aspect: 'equal', 'auto', or None
      - 'equal': Ensures an aspect ratio of 1 for pixels (square pixels)
      - 'auto': The axes is kept fixed and the aspect ratio of pixels is
        adjusted so that the data fit in the axes. In general, this will
        result in non-square pixels.
      - if None, 'equal' is used for numpy arrays and 'auto' for xarrays
        (which have typically heterogeneous coordinates)

    contrast_rescaling: 'minmax', 'infer', or None
        how to determine data values corresponding to the bounds of the color
        range, when zmin or zmax are not passed. If `minmax`, the min and max
        values of the image are used. If `infer`, a heuristic based on the image
        data type is used.

    binary_string: bool, default None
        if True, the image data are first rescaled and encoded as uint8 and
        then passed to plotly.js as a b64 PNG string. If False, data are passed
        unchanged as a numerical array. Setting to True may lead to performance
        gains, at the cost of a loss of precision depending on the original data
        type. If None, `binary_string` is set to True for multichannel (eg) RGB
        arrays, and to False for single-channel (2D) arrays. 2D arrays are
        represented as grayscale and with no colorbar if `binary_string` is
        True.

    binary_backend: str, 'auto' (default), 'pil' or 'pypng'
        Third-party package for the transformation of numpy arrays to
        png b64 strings. If 'auto', Pillow is used if installed, otherwise
        pypng.

    binary_compression_level: int, between 0 and 9 (default 4)
        png compression level to be passed to the backend when transforming an
        array to a png b64 string. Increasing `binary_compression_level` decreases
        the size of the png string, but the compression step takes more time. For most
        images it is not worth using levels greater than 5, but it's possible to
        test `len(fig.data[0].source)` and to time the execution of `imshow` to
        tune the level of compression. 0 means no compression (not recommended).

    binary_format: str, 'png' (default) or 'jpg'
        compression format used to generate b64 string. 'png' is recommended
        since it uses lossless compression, but 'jpg' (lossy) compression can
        result in smaller binary strings for natural images.

    text_auto: bool or str (default `False`)
        If `True` or a string, single-channel `img` values will be displayed as text.
        A string like `'.2f'` will be interpreted as a `texttemplate` numeric formatting directive.

    Returns
    -------
    fig : graph_objects.Figure containing the displayed image

    Raises
    ------
    ValueError
        If `binary_string` is requested for a pandas DataFrame, if the lengths
        of `x`/`y` do not match the image dimensions, if non-numerical `x`/`y`
        are used with an Image trace, or if the shape of `img` is not supported.

    See also
    --------

    plotly.graph_objects.Image : image trace
    plotly.graph_objects.Heatmap : heatmap trace

    Notes
    -----

    In order to update and customize the returned figure, use
    `go.Figure.update_traces` or `go.Figure.update_layout`.

    If an xarray is passed, dimensions names and coordinates are used for
    axes labels and ticks.
    """
    # NOTE: `locals()` must be captured before any other local is created, so
    # that `args` holds exactly the call arguments.
    args = locals()
    apply_default_cascade(args)
    labels = labels.copy()
    nslices_facet = 1
    if facet_col is not None:
        if isinstance(facet_col, str):
            facet_col = img.dims.index(facet_col)
        nslices_facet = img.shape[facet_col]
        facet_slices = range(nslices_facet)
        ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
        nrows = (
            nslices_facet // ncols + 1
            if nslices_facet % ncols
            else nslices_facet // ncols
        )
    else:
        nrows = 1
        ncols = 1
    if animation_frame is not None:
        if isinstance(animation_frame, str):
            animation_frame = img.dims.index(animation_frame)
        nslices_animation = img.shape[animation_frame]
        animation_slices = range(nslices_animation)
    slice_dimensions = (facet_col is not None) + (
        animation_frame is not None
    )  # 0, 1, or 2
    facet_label = None
    animation_label = None
    img_is_xarray = False
    # ----- Define x and y, set labels if img is an xarray -------------------
    if xarray_imported and isinstance(img, xarray.DataArray):
        dims = list(img.dims)
        img_is_xarray = True
        if facet_col is not None:
            facet_slices = img.coords[img.dims[facet_col]].values
            _ = dims.pop(facet_col)
            facet_label = img.dims[facet_col]
        if animation_frame is not None:
            animation_slices = img.coords[img.dims[animation_frame]].values
            _ = dims.pop(animation_frame)
            animation_label = img.dims[animation_frame]
        y_label, x_label = dims[0], dims[1]
        # np.datetime64 is not handled correctly by go.Heatmap
        for ax in [x_label, y_label]:
            if np.issubdtype(img.coords[ax].dtype, np.datetime64):
                img.coords[ax] = img.coords[ax].astype(str)
        if x is None:
            x = img.coords[x_label].values
        if y is None:
            y = img.coords[y_label].values
        if aspect is None:
            aspect = "auto"
        if labels.get("x", None) is None:
            labels["x"] = x_label
        if labels.get("y", None) is None:
            labels["y"] = y_label
        if labels.get("animation_frame", None) is None:
            labels["animation_frame"] = animation_label
        if labels.get("facet_col", None) is None:
            labels["facet_col"] = facet_label
        if labels.get("color", None) is None:
            labels["color"] = xarray.plot.utils.label_from_attrs(img)
            labels["color"] = labels["color"].replace("\n", "<br>")
    else:
        # DataFrame-like inputs: use column/index values and names as defaults
        if hasattr(img, "columns") and hasattr(img.columns, "__len__"):
            if x is None:
                x = img.columns
            if labels.get("x", None) is None and hasattr(img.columns, "name"):
                labels["x"] = img.columns.name or ""
        if hasattr(img, "index") and hasattr(img.index, "__len__"):
            if y is None:
                y = img.index
            if labels.get("y", None) is None and hasattr(img.index, "name"):
                labels["y"] = img.index.name or ""

        if labels.get("x", None) is None:
            labels["x"] = ""
        if labels.get("y", None) is None:
            labels["y"] = ""
        if labels.get("color", None) is None:
            labels["color"] = ""
        if aspect is None:
            aspect = "equal"

    # --- Set the value of binary_string (forbidden for pandas)
    if isinstance(img, pd.DataFrame):
        if binary_string:
            raise ValueError("Binary strings cannot be used with pandas arrays")
        is_dataframe = True
    else:
        is_dataframe = False

    # --------------- Starting from here img is always a numpy array --------
    img = np.asanyarray(img)
    # Reshape array so that animation dimension comes first, then facets, then images
    if facet_col is not None:
        img = np.moveaxis(img, facet_col, 0)
        # Moving the facet axis to the front shifts any earlier axis by one
        if animation_frame is not None and animation_frame < facet_col:
            animation_frame += 1
        facet_col = True
    if animation_frame is not None:
        img = np.moveaxis(img, animation_frame, 0)
        animation_frame = True
        args["animation_frame"] = (
            "animation_frame"
            if labels.get("animation_frame") is None
            else labels["animation_frame"]
        )
    iterables = ()
    if animation_frame is not None:
        iterables += (range(nslices_animation),)
    if facet_col is not None:
        iterables += (range(nslices_facet),)

    # Default behaviour of binary_string: True for RGB images, False for 2D
    if binary_string is None:
        binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe

    # Cast bools to uint8 (also one byte)
    if img.dtype == bool:
        img = 255 * img.astype(np.uint8)

    if range_color is not None:
        zmin = range_color[0]
        zmax = range_color[1]

    # -------- Contrast rescaling: either minmax or infer ------------------
    if contrast_rescaling is None:
        contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer"

    # We try to set zmin and zmax only if necessary, because traces have good defaults
    if contrast_rescaling == "minmax":
        # When using binary_string and minmax we need to set zmin and zmax to rescale the image
        if (zmin is not None or binary_string) and zmax is None:
            zmax = img.max()
        if (zmax is not None or binary_string) and zmin is None:
            zmin = img.min()
    else:
        # For uint8 data and infer we let zmin and zmax to be None if passed as None
        if zmax is None and img.dtype != np.uint8:
            zmax = _infer_zmax_from_type(img)
        if zmin is None and zmax is not None:
            zmin = 0

    # For 2d data, use Heatmap trace, unless binary_string is True
    if img.ndim == 2 + slice_dimensions and not binary_string:
        y_index = slice_dimensions
        if y is not None and img.shape[y_index] != len(y):
            raise ValueError(
                "The length of the y vector must match the length of the first "
                + "dimension of the img matrix."
            )
        x_index = slice_dimensions + 1
        if x is not None and img.shape[x_index] != len(x):
            raise ValueError(
                "The length of the x vector must match the length of the second "
                + "dimension of the img matrix."
            )

        texttemplate = None
        if text_auto is True:
            texttemplate = "%{z}"
        elif text_auto is not False:
            texttemplate = "%{z:" + text_auto + "}"

        traces = [
            go.Heatmap(
                x=x,
                y=y,
                z=img[index_tup],
                coloraxis="coloraxis1",
                name=str(i),
                texttemplate=texttemplate,
            )
            for i, index_tup in enumerate(itertools.product(*iterables))
        ]
        autorange = True if origin == "lower" else "reversed"
        layout = dict(yaxis=dict(autorange=autorange))
        if aspect == "equal":
            layout["xaxis"] = dict(scaleanchor="y", constrain="domain")
            layout["yaxis"]["constrain"] = "domain"
        colorscale_validator = ColorscaleValidator("colorscale", "imshow")
        layout["coloraxis1"] = dict(
            colorscale=colorscale_validator.validate_coerce(
                args["color_continuous_scale"]
            ),
            cmid=color_continuous_midpoint,
            cmin=zmin,
            cmax=zmax,
        )
        if labels["color"]:
            layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])

    # For 2D+RGB data, use Image trace
    elif (
        img.ndim >= 3
        and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string)
    ) or (img.ndim == 2 and binary_string):
        rescale_image = True  # to check whether image has been modified
        if zmin is not None and zmax is not None:
            zmin, zmax = (
                _vectorize_zvalue(zmin, mode="min"),
                _vectorize_zvalue(zmax, mode="max"),
            )
        x0, y0, dx, dy = (None,) * 4
        error_msg_xarray = (
            "Non-numerical coordinates were passed with xarray `img`, but "
            "the Image trace cannot handle it. Please use `binary_string=False` "
            "for 2D data or pass instead the numpy array `img.values` to `px.imshow`."
        )
        if x is not None:
            x = np.asanyarray(x)
            if np.issubdtype(x.dtype, np.number):
                # The Image trace only takes an origin and a constant step
                x0 = x[0]
                dx = x[1] - x[0]
            else:
                error_msg = (
                    error_msg_xarray
                    if img_is_xarray
                    else (
                        "Only numerical values are accepted for the `x` parameter "
                        "when an Image trace is used."
                    )
                )
                raise ValueError(error_msg)
        if y is not None:
            y = np.asanyarray(y)
            if np.issubdtype(y.dtype, np.number):
                y0 = y[0]
                dy = y[1] - y[0]
            else:
                error_msg = (
                    error_msg_xarray
                    if img_is_xarray
                    else (
                        "Only numerical values are accepted for the `y` parameter "
                        "when an Image trace is used."
                    )
                )
                raise ValueError(error_msg)
        if binary_string:
            if zmin is None and zmax is None:  # no rescaling, faster
                img_rescaled = img
                rescale_image = False
            elif img.ndim == 2 + slice_dimensions:  # single-channel image
                img_rescaled = rescale_intensity(
                    img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
                )
            else:
                img_rescaled = np.stack(
                    [
                        rescale_intensity(
                            img[..., ch],
                            in_range=(zmin[ch], zmax[ch]),
                            out_range=np.uint8,
                        )
                        for ch in range(img.shape[-1])
                    ],
                    axis=-1,
                )
            img_str = [
                image_array_to_data_uri(
                    img_rescaled[index_tup],
                    backend=binary_backend,
                    compression=binary_compression_level,
                    ext=binary_format,
                )
                for index_tup in itertools.product(*iterables)
            ]

            traces = [
                go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy)
                for i, img_str_slice in enumerate(img_str)
            ]
        else:
            colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
            traces = [
                go.Image(
                    z=img[index_tup],
                    zmin=zmin,
                    zmax=zmax,
                    colormodel=colormodel,
                    x0=x0,
                    y0=y0,
                    dx=dx,
                    dy=dy,
                )
                for index_tup in itertools.product(*iterables)
            ]
        layout = {}
        if origin == "lower" or (dy is not None and dy < 0):
            layout["yaxis"] = dict(autorange=True)
        if dx is not None and dx < 0:
            layout["xaxis"] = dict(autorange="reversed")
    else:
        raise ValueError(
            "px.imshow only accepts 2D single-channel, RGB or RGBA images. "
            "An image of shape %s was provided. "
            "Alternatively, 3- or 4-D single or multichannel datasets can be "
            "visualized using the `facet_col` or/and `animation_frame` arguments."
            % str(img.shape)
        )

    # Now build figure
    col_labels = []
    if facet_col is not None:
        slice_label = (
            "facet_col" if labels.get("facet_col") is None else labels["facet_col"]
        )
        # `facet_slices` holds xarray coordinate values when img is an xarray;
        # these may be strings or floats, so format with %s rather than %d
        # (which raises on strings and truncates floats).
        col_labels = ["%s=%s" % (slice_label, i) for i in facet_slices]
    fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
    for attr_name in ["height", "width"]:
        if args[attr_name]:
            layout[attr_name] = args[attr_name]
    if args["title"]:
        layout["title_text"] = args["title"]
    elif args["template"].layout.margin.t is None:
        layout["margin"] = {"t": 60}

    frame_list = []
    for index, trace in enumerate(traces):
        # Only the first frame's traces are added to the figure itself
        if (facet_col and index < nrows * ncols) or index == 0:
            fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
    if animation_frame is not None:
        for i, index in zip(range(nslices_animation), animation_slices):
            frame_list.append(
                dict(
                    data=traces[nslices_facet * i : nslices_facet * (i + 1)],
                    layout=layout,
                    name=str(index),
                )
            )
    if animation_frame:
        fig.frames = frame_list
    fig.update_layout(layout)
    # Hover name, z or color
    if binary_string and rescale_image and not np.all(img == img_rescaled):
        # we rescaled the image, hence z is not displayed in hover since it does
        # not correspond to img values
        hovertemplate = "%s: %%{x}<br>%s: %%{y}<extra></extra>" % (
            labels["x"] or "x",
            labels["y"] or "y",
        )
    else:
        # `trace` is the last trace built above; all traces share one type
        if trace["type"] == "heatmap":
            hover_name = "%{z}"
        elif img.ndim == 2:
            hover_name = "%{z[0]}"
        elif img.ndim == 3 and img.shape[-1] == 3:
            hover_name = "[%{z[0]}, %{z[1]}, %{z[2]}]"
        else:
            hover_name = "%{z}"
        hovertemplate = "%s: %%{x}<br>%s: %%{y}<br>%s: %s<extra></extra>" % (
            labels["x"] or "x",
            labels["y"] or "y",
            labels["color"] or "color",
            hover_name,
        )
    fig.update_traces(hovertemplate=hovertemplate)
    if labels["x"]:
        fig.update_xaxes(title_text=labels["x"], row=1)
    if labels["y"]:
        fig.update_yaxes(title_text=labels["y"], col=1)
    configure_animation_controls(args, go.Image, fig)
    fig.update_layout(template=args["template"], overwrite=True)
    return fig

View File

@ -0,0 +1,40 @@
class IdentityMap(object):
    """
    `dict`-like stand-in whose lookup of any key yields that same key.

    Pass an instance to any `_map` argument of a Plotly Express function
    (e.g. `color_discrete_map`, `line_dash_map` or `symbol_map`) to use the
    provided data values directly, rather than mapping them to values cycled
    from the matching sequence such as `color_discrete_sequence`.
    """

    def __contains__(self, key):
        """Report every key as present."""
        return True

    def __getitem__(self, key):
        """Return `key` unchanged."""
        return key

    def copy(self):
        """Return this very instance rather than a new object."""
        return self
class Constant(object):
    """
    Marker for an attribute that should take a single constant value.

    Instances can be passed to Plotly Express functions wherever a column
    identifier or list-like object is expected. An optional label can be
    provided.
    """

    def __init__(self, value, label=None):
        """Store the constant `value` and its optional `label`."""
        self.value, self.label = value, label
class Range(object):
    """
    Marker for an attribute that should be mapped onto integers starting at 0.

    Instances can be passed to Plotly Express functions wherever a column
    identifier or list-like object is expected. An optional label can be
    provided.
    """

    def __init__(self, label=None):
        """Store the optional `label`; no other state is kept."""
        self.label = label

View File

@ -0,0 +1,52 @@
"""For a list of colors available in `plotly.express.colors`, please see
* the `tutorial on discrete color sequences <https://plotly.com/python/discrete-color/#color-sequences-in-plotly-express>`_
* the `list of built-in continuous color scales <https://plotly.com/python/builtin-colorscales/>`_
* the `tutorial on continuous colors <https://plotly.com/python/colorscales/>`_
Color scales are available within the following namespaces
* cyclical
* diverging
* qualitative
* sequential
"""
from __future__ import absolute_import
from plotly.colors import *
# Names exported by a star-import of this module. Each entry is listed exactly
# once (the original list repeated "colorbrewer").
# NOTE(review): every name is presumably provided by the
# `from plotly.colors import *` above -- confirm when adding entries.
__all__ = [
    "named_colorscales",
    "cyclical",
    "diverging",
    "sequential",
    "qualitative",
    "colorbrewer",
    "carto",
    "cmocean",
    "color_parser",
    "colorscale_to_colors",
    "colorscale_to_scale",
    "convert_colors_to_same_type",
    "convert_colorscale_to_rgb",
    "convert_dict_colors_to_same_type",
    "convert_to_RGB_255",
    "find_intermediate_color",
    "hex_to_rgb",
    "label_rgb",
    "make_colorscale",
    "n_colors",
    "unconvert_from_RGB_255",
    "unlabel_rgb",
    "validate_colors",
    "validate_colors_dict",
    "validate_colorscale",
    "validate_scale_values",
    "plotlyjs",
    "DEFAULT_PLOTLY_COLORS",
    "PLOTLY_SCALES",
    "get_colorscale",
    "sample_colorscale",
]

View File

@ -0,0 +1,19 @@
"""Built-in datasets for demonstration, educational and test purposes.
"""
from __future__ import absolute_import
from plotly.data import *
# Names exported by a star-import of this module, grouped a few per line.
# fmt: off
__all__ = [
    "carshare", "election", "election_geojson", "experiment",
    "gapminder", "iris", "medals_wide", "medals_long",
    "stocks", "tips", "wind",
]
# fmt: on

View File

@ -0,0 +1,248 @@
"""Vendored code from scikit-image in order to limit the number of dependencies
Extracted from scikit-image/skimage/exposure/exposure.py
"""
import numpy as np
from warnings import warn
# NumPy integer scalar types for which a (min, max) range is tabulated.
_integer_types = (
    np.byte,
    np.ubyte,  # 8 bits
    np.short,
    np.ushort,  # 16 bits
    np.intc,
    np.uintc,  # 16 or 32 or 64 bits
    np.int_,
    np.uint,  # 32 or 64 bits
    np.longlong,
    np.ulonglong,
)  # 64 bits
# Maps each integer scalar type to the (min, max) reported by `np.iinfo`.
_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types}
# Intensity range per scalar type. `np.bool8` is deliberately not listed: it
# was only an alias of `np.bool_` (i.e. the same dict key) and has been
# removed in NumPy 2.0, where referencing it raises AttributeError.
dtype_range = {
    np.bool_: (False, True),
    np.float16: (-1, 1),
    np.float32: (-1, 1),
    np.float64: (-1, 1),
}
dtype_range.update(_integer_ranges)

# Same table, additionally keyed by type name (e.g. "uint8") and by a few
# names that are not NumPy dtypes (e.g. "uint12").
DTYPE_RANGE = dtype_range.copy()
DTYPE_RANGE.update((d.__name__, limits) for d, limits in dtype_range.items())
DTYPE_RANGE.update(
    {
        "uint10": (0, 2**10 - 1),
        "uint12": (0, 2**12 - 1),
        "uint14": (0, 2**14 - 1),
        "bool": dtype_range[np.bool_],
        "float": dtype_range[np.float64],
    }
)
def intensity_range(image, range_values="image", clip_negative=False):
    """Return image intensity range (min, max) based on desired value type.

    Parameters
    ----------
    image : array
        Input image.
    range_values : str, type, or 2-item sequence, optional
        The image intensity range is configured by this parameter.
        The possible values for this parameter are enumerated below.

        'image'
            Return image min/max as the range.
        'dtype'
            Return min/max of the image's dtype as the range.
        dtype-name
            Return intensity range based on desired `dtype`. Must be valid key
            in `DTYPE_RANGE`. Note: `image` is ignored for this range type.
        2-tuple (or 2-item list / array)
            Return `range_values` as min/max intensities. Note that there's no
            reason to use this function if you just want to specify the
            intensity range explicitly. This option is included for functions
            that use `intensity_range` to support all desired range types.
    clip_negative : bool, optional
        If True, clip the negative range (i.e. return 0 for min intensity)
        even if the image dtype allows negative values. Only applied when the
        range is looked up in `DTYPE_RANGE`.

    Returns
    -------
    i_min, i_max : pair of scalars
        The lower and upper intensity bounds.
    """
    # Explicit (min, max) pairs are handled first: lists and arrays are
    # unhashable and cannot be safely compared to a string, so they must not
    # reach the string comparisons or the `DTYPE_RANGE` lookup below.
    if isinstance(range_values, (list, tuple, np.ndarray)):
        i_min, i_max = range_values
        return i_min, i_max

    if range_values == "dtype":
        range_values = image.dtype.type

    if range_values == "image":
        i_min = np.min(image)
        i_max = np.max(image)
    elif range_values in DTYPE_RANGE:
        i_min, i_max = DTYPE_RANGE[range_values]
        if clip_negative:
            i_min = 0
    else:
        # Any other 2-item iterable is unpacked as an explicit range.
        i_min, i_max = range_values
    return i_min, i_max
def _output_dtype(dtype_or_range):
    """Determine the output dtype for rescale_intensity.

    The dtype is determined according to the following rules:
    - if ``dtype_or_range`` is a dtype, that is the output dtype.
    - if ``dtype_or_range`` is a dtype string, that is the dtype used, unless
      it is not a NumPy data type (e.g. 'uint12' for 12-bit unsigned integers),
      in which case the data type that can contain it will be used
      (e.g. uint16 in this case).
    - if ``dtype_or_range`` is a pair of values, the output data type will be
      float.

    Parameters
    ----------
    dtype_or_range : type, string, or 2-tuple of int/float
        The desired range for the output, expressed as either a NumPy dtype or
        as a (min, max) pair of numbers.

    Returns
    -------
    out_dtype : type
        The data type appropriate for the desired output.

    Raises
    ------
    ValueError
        If ``dtype_or_range`` is neither a sequence, a type, nor a key of
        `DTYPE_RANGE`.
    """
    if isinstance(dtype_or_range, (list, tuple, np.ndarray)):
        # pair of values: always return float.
        # `np.float64` is used because the `np.float_` alias was removed in
        # NumPy 2.0; both names refer to the same type on older versions.
        return np.float64
    if isinstance(dtype_or_range, type):
        # already a type: return it
        return dtype_or_range
    if dtype_or_range in DTYPE_RANGE:
        # string key in DTYPE_RANGE dictionary
        try:
            # if it's a canonical numpy dtype, convert
            return np.dtype(dtype_or_range).type
        except TypeError:  # uint10, uint12, uint14
            # otherwise, return uint16
            return np.uint16
    else:
        raise ValueError(
            "Incorrect value for out_range, should be a valid image data "
            "type or a pair of values, got %s." % str(dtype_or_range)
        )
def rescale_intensity(image, in_range="image", out_range="dtype"):
    """Stretch or shrink the intensity levels of an image.

    Values of `image` are clipped to the input range and then mapped linearly
    from that input range onto the output range.

    Parameters
    ----------
    image : array
        Image array.
    in_range, out_range : str or 2-tuple, optional
        Min and max intensity values of input and output image.
        The possible values for this parameter are enumerated below.

        'image'
            Use image min/max as the intensity range.
        'dtype'
            Use min/max of the image's dtype as the intensity range.
        dtype-name
            Use intensity range based on desired `dtype`. Must be valid key
            in `DTYPE_RANGE`.
        2-tuple
            Use `range_values` as explicit min/max intensities.

    Returns
    -------
    out : array
        Image array after rescaling its intensity. Its dtype is obtained from
        `_output_dtype`: applied to the dtype of `image` when `out_range` is
        'dtype' or 'image', and applied to `out_range` itself otherwise.

    Notes
    -----
    .. versionchanged:: 0.17
        The dtype of the output array has changed to match the output dtype, or
        float if the output range is specified by a pair of floats.

    A warning is emitted when any of the four range bounds is NaN.

    See Also
    --------
    equalize_hist

    Examples
    --------
    By default, the min/max intensities of the input image are stretched to
    the limits allowed by the image's dtype, since `in_range` defaults to
    'image' and `out_range` defaults to 'dtype':

    >>> image = np.array([51, 102, 153], dtype=np.uint8)
    >>> rescale_intensity(image)
    array([  0, 127, 255], dtype=uint8)

    It's easy to accidentally convert an image dtype from uint8 to float:

    >>> 1.0 * image
    array([ 51., 102., 153.])

    Use `rescale_intensity` to rescale to the proper range for float dtypes:

    >>> image_float = 1.0 * image
    >>> rescale_intensity(image_float)
    array([0. , 0.5, 1. ])

    To maintain the low contrast of the original, use the `in_range` parameter:

    >>> rescale_intensity(image_float, in_range=(0, 255))
    array([0.2, 0.4, 0.6])

    If the min/max value of `in_range` is more/less than the min/max image
    intensity, then the intensity levels are clipped:

    >>> rescale_intensity(image_float, in_range=(0, 102))
    array([0.5, 1. , 1. ])

    If you have an image with signed integers but want to rescale the image to
    just the positive range, use the `out_range` parameter. In that case, the
    output dtype will be float:

    >>> image = np.array([-10, 0, 10], dtype=np.int8)
    >>> rescale_intensity(image, out_range=(0, 127))
    array([  0. ,  63.5, 127. ])

    To get the desired range with a specific dtype, use ``.astype()``:

    >>> rescale_intensity(image, out_range=(0, 127)).astype(np.int8)
    array([  0,  63, 127], dtype=int8)

    If the input image is constant, the output will be clipped directly to the
    output range:

    >>> image = np.array([130, 130, 130], dtype=np.int32)
    >>> rescale_intensity(image, out_range=(0, 127)).astype(np.int32)
    array([127, 127, 127], dtype=int32)
    """
    # The keywords 'dtype' and 'image' both mean "keep the input's dtype".
    dtype_source = image.dtype.type if out_range in ["dtype", "image"] else out_range
    out_dtype = _output_dtype(dtype_source)

    in_lo, in_hi = (float(bound) for bound in intensity_range(image, in_range))
    out_bounds = intensity_range(image, out_range, clip_negative=(in_lo >= 0))
    out_lo, out_hi = (float(bound) for bound in out_bounds)

    if np.isnan([in_lo, in_hi, out_lo, out_hi]).any():
        warn(
            "One or more intensity levels are NaN. Rescaling will broadcast "
            "NaN to the full image. Provide intensity levels yourself to "
            "avoid this. E.g. with np.nanmin(image), np.nanmax(image).",
            stacklevel=2,
        )

    clipped = np.clip(image, in_lo, in_hi)

    if in_lo == in_hi:
        # Degenerate input range: nothing to stretch, only clip to the output.
        return np.clip(clipped, out_lo, out_hi).astype(out_dtype)

    normalized = (clipped - in_lo) / (in_hi - in_lo)
    return np.asarray(normalized * (out_hi - out_lo) + out_lo, dtype=out_dtype)

View File

@ -0,0 +1,157 @@
"""
The `trendline_functions` module contains functions which are called by Plotly Express
when the `trendline` argument is used. Valid values for `trendline` are the names of the
functions in this module, and the value of the `trendline_options` argument to PX
functions is passed in as the first argument to these functions when called.
Note that the functions in this module are not meant to be called directly, and are
exposed as part of the public API for documentation purposes.
"""
import pandas as pd
import numpy as np
__all__ = ["ols", "lowess", "rolling", "ewm", "expanding"]
def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """Ordinary Least Squares (OLS) trendline function

    Requires `statsmodels` to be installed.

    This trendline function causes fit results to be stored within the figure,
    accessible via the `plotly.express.get_trendline_results` function. The fit results
    are the output of the `statsmodels.api.OLS` function.

    Valid keys for the `trendline_options` dict are:

    - `add_constant` (`bool`, default `True`): if `False`, the trendline passes through
    the origin but if `True` a y-intercept is fitted.

    - `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with
    respect to the base 10 logarithm of the input. Note that this means no zeros can
    be present in the input.

    Returns a `(y_out, hover_header, fit_results)` tuple. Raises `ValueError`
    for an unknown option key, or when a log transform is requested on data
    containing non-positive values.
    """
    valid_options = ["add_constant", "log_x", "log_y"]
    unknown_keys = [k for k in trendline_options.keys() if k not in valid_options]
    if unknown_keys:
        # Report the first offending key, in the dict's own order.
        raise ValueError(
            "OLS trendline_options keys must be one of [%s] but got '%s'"
            % (", ".join(valid_options), unknown_keys[0])
        )

    # Imported lazily so that statsmodels is only needed when OLS is used.
    import statsmodels.api as sm

    add_constant = trendline_options.get("add_constant", True)
    log_x = trendline_options.get("log_x", False)
    log_y = trendline_options.get("log_y", False)

    if log_y:
        if np.any(y <= 0):
            raise ValueError(
                "Can't do OLS trendline with `log_y=True` when `y` contains non-positive values."
            )
        y, y_label = np.log10(y), "log10(%s)" % y_label
    if log_x:
        if np.any(x <= 0):
            raise ValueError(
                "Can't do OLS trendline with `log_x=True` when `x` contains non-positive values."
            )
        x, x_label = np.log10(x), "log10(%s)" % x_label
    if add_constant:
        x = sm.add_constant(x)

    fit_results = sm.OLS(y, x, missing="drop").fit()
    y_out = fit_results.predict()
    if log_y:
        # Undo the log10 transform so predictions are in the original units.
        y_out = np.power(10, y_out)

    if len(fit_results.params) == 2:
        equation = "%s = %g * %s + %g<br>" % (
            y_label,
            fit_results.params[1],
            x_label,
            fit_results.params[0],
        )
    elif not add_constant:
        equation = "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label)
    else:
        equation = "%s = %g<br>" % (y_label, fit_results.params[0])
    r_squared = "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
    hover_header = "<b>OLS trendline</b><br>" + equation + r_squared
    return y_out, hover_header, fit_results
def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function

    Requires `statsmodels` to be installed.

    Valid keys for the `trendline_options` dict are:

    - `frac` (`float`, default `0.6666666`): the `frac` parameter from the
    `statsmodels.api.nonparametric.lowess` function

    Returns a `(y_out, hover_header, None)` tuple. Raises `ValueError` for an
    unknown option key.
    """
    valid_options = ["frac"]
    unknown_keys = [k for k in trendline_options.keys() if k not in valid_options]
    if unknown_keys:
        # Report the first offending key, in the dict's own order.
        raise ValueError(
            "LOWESS trendline_options keys must be one of [%s] but got '%s'"
            % (", ".join(valid_options), unknown_keys[0])
        )

    # Imported lazily so that statsmodels is only needed when LOWESS is used.
    import statsmodels.api as sm

    frac = trendline_options.get("frac", 0.6666666)
    smoothed = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)
    # Keep the second column of the result as the trendline values.
    y_out = smoothed[:, 1]
    return y_out, "<b>LOWESS trendline</b><br><br>", None
def _pandas(mode, trendline_options, x_raw, y, non_missing):
    """
    Shared implementation of the pandas-based trendlines.

    `mode` names the `pandas.Series` method to use ("rolling", "ewm" or
    "expanding"). The `function` key (default "mean") and `function_args` key
    (default `{}`) are removed from a copy of `trendline_options`; what is
    left is passed as keyword arguments to the `mode` method. The named
    function is then called on the result and the output is indexed with
    `non_missing`.

    Returns a `(y_out, hover_header, None)` tuple.
    """
    display_names = {
        "rolling": "Rolling",
        "ewm": "Exponentially Weighted",
        "expanding": "Expanding",
    }
    # Work on a copy so that the caller's options dict is left untouched.
    window_kwargs = trendline_options.copy()
    function_name = window_kwargs.pop("function", "mean")
    function_args = window_kwargs.pop("function_args", dict())

    # e.g. pd.Series(y, index=x_raw).rolling(**window_kwargs)
    windowed = getattr(pd.Series(y, index=x_raw), mode)(**window_kwargs)
    # e.g. windowed.mean(**function_args)
    aggregated = getattr(windowed, function_name)(**function_args)
    y_out = aggregated[non_missing]

    hover_header = "<b>%s %s trendline</b><br><br>" % (
        display_names[mode],
        function_name,
    )
    return y_out, hover_header, None
def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """Rolling trendline function

    The `function` key of the `trendline_options` dict names the function to
    apply (defaults to `mean`) and the `function_args` key holds that
    function's arguments as a dict. Every other entry of `trendline_options`
    is passed as a keyword argument to `pandas.Series.rolling`.
    """
    return _pandas(
        mode="rolling",
        trendline_options=trendline_options,
        x_raw=x_raw,
        y=y,
        non_missing=non_missing,
    )
def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """Expanding trendline function

    The `function` key of the `trendline_options` dict names the function to
    apply (defaults to `mean`) and the `function_args` key holds that
    function's arguments as a dict. Every other entry of `trendline_options`
    is passed as a keyword argument to `pandas.Series.expanding`.
    """
    return _pandas(
        mode="expanding",
        trendline_options=trendline_options,
        x_raw=x_raw,
        y=y,
        non_missing=non_missing,
    )
def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """Exponentially Weighted Moment (EWM) trendline function

    The `function` key of the `trendline_options` dict names the function to
    apply (defaults to `mean`) and the `function_args` key holds that
    function's arguments as a dict. Every other entry of `trendline_options`
    is passed as a keyword argument to `pandas.Series.ewm`.
    """
    return _pandas(
        mode="ewm",
        trendline_options=trendline_options,
        x_raw=x_raw,
        y=y,
        non_missing=non_missing,
    )