mirror of
https://github.com/aykhans/AzSuicideDataVisualization.git
synced 2025-07-01 14:07:48 +00:00
450 lines
14 KiB
Python
450 lines
14 KiB
Python
from __future__ import absolute_import
|
|
|
|
from plotly import exceptions, optional_imports
|
|
from plotly.figure_factory import utils
|
|
from plotly.graph_objs import graph_objs
|
|
|
|
# Optional imports, may be None for users that only use our core functionality.
|
|
np = optional_imports.get_module("numpy")
|
|
pd = optional_imports.get_module("pandas")
|
|
scipy = optional_imports.get_module("scipy")
|
|
scipy_stats = optional_imports.get_module("scipy.stats")
|
|
|
|
|
|
# Histogram normalisation modes understood by this module.
# DEFAULT_HISTNORM is the default for create_distplot(histnorm=...).
DEFAULT_HISTNORM = "probability density"
# When histnorm equals ALTERNATIVE_HISTNORM the fitted curves are multiplied by
# the bin size (see _Distplot.make_kde / _Distplot.make_normal).
ALTERNATIVE_HISTNORM = "probability"
|
|
|
|
|
|
def validate_distplot(hist_data, curve_type):
    """
    Distplot-specific validations

    Checks are performed in this order: shape of ``hist_data``, value of
    ``curve_type``, availability of scipy.

    :param hist_data: Collection of data sets. Only the first element is
        type-checked; it must be a list, or (when the optional modules are
        importable) a numpy ndarray or a pandas Series.
    :param (str) curve_type: 'kde' or 'normal'.

    :raises: (PlotlyError) If hist_data is empty or is not a list of lists
    :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
        'normal').
    :raises: (ImportError) If scipy is not available.
    """
    # numpy / pandas are optional, so the accepted container types are only
    # extended when the corresponding module could be imported.
    accepted_types = [list]
    if np:
        accepted_types.append(np.ndarray)
    if pd:
        accepted_types.append(pd.core.series.Series)

    # An empty hist_data used to surface as a bare IndexError from
    # ``hist_data[0]``; report it through the same PlotlyError as any other
    # malformed input. ``len(...) == 0`` is used instead of truthiness so
    # that numpy arrays do not raise an "ambiguous truth value" error.
    if len(hist_data) == 0 or not isinstance(hist_data[0], tuple(accepted_types)):
        raise exceptions.PlotlyError(
            "Oops, this function was written "
            "to handle multiple datasets, if "
            "you want to plot just one, make "
            "sure your hist_data variable is "
            "still a list of lists, i.e. x = "
            "[1, 2, 3] -> x = [[1, 2, 3]]"
        )

    if curve_type not in ("kde", "normal"):
        raise exceptions.PlotlyError(
            "curve_type must be defined as " "'kde' or 'normal'"
        )

    if not scipy:
        raise ImportError("FigureFactory.create_distplot requires scipy")
|
|
|
|
|
|
def create_distplot(
    hist_data,
    group_labels,
    bin_size=1.0,
    curve_type="kde",
    colors=None,
    rug_text=None,
    histnorm=DEFAULT_HISTNORM,
    show_hist=True,
    show_curve=True,
    show_rug=True,
):
    """
    Function that creates a distplot similar to seaborn.distplot;
    **this function is deprecated**, use instead :mod:`plotly.express`
    functions, for example

    >>> import plotly.express as px
    >>> tips = px.data.tips()
    >>> fig = px.histogram(tips, x="total_bill", y="tip", color="sex", marginal="rug",
    ...                    hover_data=tips.columns)
    >>> fig.show()


    The distplot can be composed of all or any combination of the following
    3 components: (1) histogram, (2) curve: (a) kernel density estimation
    or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
    (from multiple datasets) can be created in the same plot.

    :param (list[list]) hist_data: Use list of lists to plot multiple data
        sets on the same plot.
    :param (list[str]) group_labels: Names for each data set.
    :param (list[float]|float) bin_size: Size of histogram bins. A single
        real number (including numpy scalar types) is applied to every
        data set; otherwise one value per data set is expected.
        Default = 1.
    :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
    :param (str) histnorm: 'probability density' or 'probability'
        Default = 'probability density'
    :param (bool) show_hist: Add histogram to distplot? Default = True
    :param (bool) show_curve: Add curve to distplot? Default = True
    :param (bool) show_rug: Add rug to distplot? Default = True
    :param (list[str]) colors: Colors for traces.
    :param (list[list]) rug_text: Hovertext values for the rug plot, one
        list per data set.
    :return (dict): Representation of a distplot figure.

    Example 1: Simple distplot of 1 data set

    >>> from plotly.figure_factory import create_distplot

    >>> hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
    ...               3.5, 4.1, 4.4, 4.5, 4.5,
    ...               5.0, 5.0, 5.2, 5.5, 5.5,
    ...               5.5, 5.5, 5.5, 6.1, 7.0]]
    >>> group_labels = ['distplot example']
    >>> fig = create_distplot(hist_data, group_labels)
    >>> fig.show()


    Example 2: Two data sets and added rug text

    >>> from plotly.figure_factory import create_distplot
    >>> # Add histogram data
    >>> hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
    ...            -0.9, -0.07, 1.95, 0.9, -0.2,
    ...            -0.5, 0.3, 0.4, -0.37, 0.6]
    >>> hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
    ...            1.0, 0.8, 1.7, 0.5, 0.8,
    ...            -0.3, 1.2, 0.56, 0.3, 2.2]

    >>> # Group data together
    >>> hist_data = [hist1_x, hist2_x]

    >>> group_labels = ['2012', '2013']

    >>> # Add text
    >>> rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
    ...       'f1', 'g1', 'h1', 'i1', 'j1',
    ...       'k1', 'l1', 'm1', 'n1', 'o1']

    >>> rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
    ...       'f2', 'g2', 'h2', 'i2', 'j2',
    ...       'k2', 'l2', 'm2', 'n2', 'o2']

    >>> # Group text together
    >>> rug_text_all = [rug_text_1, rug_text_2]

    >>> # Create distplot
    >>> fig = create_distplot(
    ...     hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)

    >>> # Add title
    >>> fig.update_layout(title='Dist Plot') # doctest: +SKIP
    >>> fig.show()


    Example 3: Plot with normal curve and hide rug plot

    >>> from plotly.figure_factory import create_distplot
    >>> import numpy as np

    >>> x1 = np.random.randn(190)
    >>> x2 = np.random.randn(200)+1
    >>> x3 = np.random.randn(200)-1
    >>> x4 = np.random.randn(210)+2

    >>> hist_data = [x1, x2, x3, x4]
    >>> group_labels = ['2012', '2013', '2014', '2015']

    >>> fig = create_distplot(
    ...     hist_data, group_labels, curve_type='normal',
    ...     show_rug=False, bin_size=.4)


    Example 4: Distplot with Pandas

    >>> from plotly.figure_factory import create_distplot
    >>> import numpy as np
    >>> import pandas as pd

    >>> df = pd.DataFrame({'2012': np.random.randn(200),
    ...                    '2013': np.random.randn(200)+1})
    >>> fig = create_distplot([df[c] for c in df.columns], df.columns)
    >>> fig.show()
    """
    # Imported locally so this function does not depend on the module's
    # import block being edited.
    import numbers

    if colors is None:
        colors = []
    if rug_text is None:
        rug_text = []

    validate_distplot(hist_data, curve_type)
    utils.validate_equal_length(hist_data, group_labels)

    # Broadcast a scalar bin size to every data set. numbers.Real covers
    # int/float/bool (as the previous ``(float, int)`` check did) and also
    # numpy scalar types such as np.int64 / np.float32, which would otherwise
    # fail later when indexed as ``bin_size[index]``.
    if isinstance(bin_size, numbers.Real):
        bin_size = [bin_size] * len(hist_data)

    def new_distplot():
        # A fresh builder per component; nothing is constructed (and no
        # min/max over the data is computed) for components that are
        # switched off.
        return _Distplot(
            hist_data,
            histnorm,
            group_labels,
            bin_size,
            curve_type,
            colors,
            rug_text,
            show_hist,
            show_curve,
        )

    # Trace order matters for rendering: histograms, then curves, then rugs.
    traces = []
    if show_hist:
        traces.extend(new_distplot().make_hist())

    if show_curve:
        builder = new_distplot()
        if curve_type == "normal":
            traces.extend(builder.make_normal())
        else:
            traces.extend(builder.make_kde())

    if show_rug:
        traces.extend(new_distplot().make_rug())

    layout_options = dict(
        barmode="overlay",
        hovermode="closest",
        legend=dict(traceorder="reversed"),
        xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
    )
    if show_rug:
        # The rug lives on a second y axis in the bottom quarter of the
        # figure; the main axis is shrunk to leave room for it.
        layout_options["yaxis1"] = dict(
            domain=[0.35, 1], anchor="free", position=0.0
        )
        layout_options["yaxis2"] = dict(
            domain=[0, 0.25], anchor="x1", dtick=1, showticklabels=False
        )
    else:
        layout_options["yaxis1"] = dict(
            domain=[0.0, 1], anchor="free", position=0.0
        )
    layout = graph_objs.Layout(**layout_options)

    return graph_objs.Figure(data=traces, layout=layout)
|
|
|
|
|
|
class _Distplot(object):
    """
    Builds the trace dictionaries (histogram, curve, rug) that
    create_distplot() assembles into a figure.

    Each ``make_*`` method returns a list with one trace dict per data set.
    See create_distplot() for the meaning of the constructor arguments.
    """

    def __init__(
        self,
        hist_data,
        histnorm,
        group_labels,
        bin_size,
        curve_type,
        colors,
        rug_text,
        show_hist,
        show_curve,
    ):
        # NOTE: curve_type is accepted for signature parity with
        # create_distplot() but is not stored on the instance.
        self.hist_data = hist_data
        self.histnorm = histnorm
        self.group_labels = group_labels
        self.bin_size = bin_size
        self.show_hist = show_hist
        self.show_curve = show_curve
        self.trace_number = len(hist_data)

        # Falsy rug_text (None / empty) means "no hover text for any trace".
        self.rug_text = rug_text if rug_text else [None] * self.trace_number

        # Falsy colors falls back to a built-in ten-colour palette; colours
        # are cycled when there are more traces than colours.
        self.colors = (
            colors
            if colors
            else [
                "rgb(31, 119, 180)",
                "rgb(255, 127, 14)",
                "rgb(44, 160, 44)",
                "rgb(214, 39, 40)",
                "rgb(148, 103, 189)",
                "rgb(140, 86, 75)",
                "rgb(227, 119, 194)",
                "rgb(127, 127, 127)",
                "rgb(188, 189, 34)",
                "rgb(23, 190, 207)",
            ]
        )

        # Filled in lazily by make_kde() / make_normal().
        self.curve_x = [None] * self.trace_number
        self.curve_y = [None] * self.trace_number

        # Per-trace data range as floats; used for histogram bins and as the
        # span over which the curves are sampled.
        bounds = [(min(trace) * 1.0, max(trace) * 1.0) for trace in self.hist_data]
        self.start = [low for low, _ in bounds]
        self.end = [high for _, high in bounds]

    def _color_for(self, index):
        """Return the colour of trace ``index``, cycling through the palette."""
        return self.colors[index % len(self.colors)]

    def _curve_grid(self, index):
        """Return 500 x positions from start[index] up to (not including) end[index]."""
        low = self.start[index]
        span = self.end[index] - low
        return [low + step * span / 500 for step in range(500)]

    def _line_traces(self):
        """Wrap the already-computed curve_x / curve_y into line traces."""
        return [
            {
                "type": "scatter",
                "x": self.curve_x[index],
                "y": self.curve_y[index],
                "xaxis": "x1",
                "yaxis": "y1",
                "mode": "lines",
                "name": self.group_labels[index],
                "legendgroup": self.group_labels[index],
                # The histogram owns the legend entry when it is shown.
                "showlegend": not self.show_hist,
                "marker": {"color": self._color_for(index)},
            }
            for index in range(self.trace_number)
        ]

    def make_hist(self):
        """
        Makes the histogram(s) for FigureFactory.create_distplot().

        :rtype (list) hist: list of histogram representations
        """
        return [
            {
                "type": "histogram",
                "x": self.hist_data[index],
                "xaxis": "x1",
                "yaxis": "y1",
                "histnorm": self.histnorm,
                "name": self.group_labels[index],
                "legendgroup": self.group_labels[index],
                "marker": {"color": self._color_for(index)},
                "autobinx": False,
                "xbins": {
                    "start": self.start[index],
                    "end": self.end[index],
                    "size": self.bin_size[index],
                },
                "opacity": 0.7,
            }
            for index in range(self.trace_number)
        ]

    def make_kde(self):
        """
        Makes the kernel density estimation(s) for create_distplot().

        This is called when curve_type = 'kde' in create_distplot().

        :rtype (list) curve: list of kde representations
        """
        rescale = self.histnorm == ALTERNATIVE_HISTNORM
        for index in range(self.trace_number):
            self.curve_x[index] = self._curve_grid(index)
            estimator = scipy_stats.gaussian_kde(self.hist_data[index])
            self.curve_y[index] = estimator(self.curve_x[index])

            if rescale:
                # Scale the density by the bin width so the curve matches a
                # histogram normalised as "probability".
                self.curve_y[index] *= self.bin_size[index]

        return self._line_traces()

    def make_normal(self):
        """
        Makes the normal curve(s) for create_distplot().

        This is called when curve_type = 'normal' in create_distplot().

        :rtype (list) curve: list of normal curve representations
        """
        rescale = self.histnorm == ALTERNATIVE_HISTNORM
        for index in range(self.trace_number):
            # Fit a normal distribution to the data, then sample its pdf.
            mean, sd = scipy_stats.norm.fit(self.hist_data[index])
            self.curve_x[index] = self._curve_grid(index)
            self.curve_y[index] = scipy_stats.norm.pdf(
                self.curve_x[index], loc=mean, scale=sd
            )

            if rescale:
                # Scale the density by the bin width so the curve matches a
                # histogram normalised as "probability".
                self.curve_y[index] *= self.bin_size[index]

        return self._line_traces()

    def make_rug(self):
        """
        Makes the rug plot(s) for create_distplot().

        :rtype (list) rug: list of rug plot representations
        """
        return [
            {
                "type": "scatter",
                "x": self.hist_data[index],
                # Every point of a data set sits on the row named after its
                # group label.
                "y": ([self.group_labels[index]] * len(self.hist_data[index])),
                "xaxis": "x1",
                "yaxis": "y2",
                "mode": "markers",
                "name": self.group_labels[index],
                "legendgroup": self.group_labels[index],
                # Only claim a legend entry when neither the histogram nor
                # the curve is shown.
                "showlegend": not (self.show_hist or self.show_curve),
                "text": self.rug_text[index],
                "marker": {
                    "color": self._color_for(index),
                    "symbol": "line-ns-open",
                },
            }
            for index in range(self.trace_number)
        ]
|