Updated script that can be controled by Nodejs web app
This commit is contained in:
98
lib/python3.13/site-packages/pandas/plotting/__init__.py
Normal file
98
lib/python3.13/site-packages/pandas/plotting/__init__.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""
|
||||
Plotting public API.
|
||||
|
||||
Authors of third-party plotting backends should implement a module with a
|
||||
public ``plot(data, kind, **kwargs)``. The parameter `data` will contain
|
||||
the data structure and can be a `Series` or a `DataFrame`. For example,
|
||||
for ``df.plot()`` the parameter `data` will contain the DataFrame `df`.
|
||||
In some cases, the data structure is transformed before being sent to
|
||||
the backend (see PlotAccessor.__call__ in pandas/plotting/_core.py for
|
||||
the exact transformations).
|
||||
|
||||
The parameter `kind` will be one of:
|
||||
|
||||
- line
|
||||
- bar
|
||||
- barh
|
||||
- box
|
||||
- hist
|
||||
- kde
|
||||
- area
|
||||
- pie
|
||||
- scatter
|
||||
- hexbin
|
||||
|
||||
See the pandas API reference for documentation on each kind of plot.
|
||||
|
||||
Any other keyword argument is currently assumed to be backend specific,
|
||||
but some parameters may be unified and added to the signature in the
|
||||
future (e.g. `title` which should be useful for any backend).
|
||||
|
||||
Currently, all the Matplotlib functions in pandas are accessed through
|
||||
the selected backend. For example, `pandas.plotting.boxplot` (equivalent
|
||||
to `DataFrame.boxplot`) is also accessed in the selected backend. This
|
||||
is expected to change, and the exact API is under discussion. But with
|
||||
the current version, backends are expected to implement the next functions:
|
||||
|
||||
- plot (describe above, used for `Series.plot` and `DataFrame.plot`)
|
||||
- hist_series and hist_frame (for `Series.hist` and `DataFrame.hist`)
|
||||
- boxplot (`pandas.plotting.boxplot(df)` equivalent to `DataFrame.boxplot`)
|
||||
- boxplot_frame and boxplot_frame_groupby
|
||||
- register and deregister (register converters for the tick formats)
|
||||
- Plots not called as `Series` and `DataFrame` methods:
|
||||
- table
|
||||
- andrews_curves
|
||||
- autocorrelation_plot
|
||||
- bootstrap_plot
|
||||
- lag_plot
|
||||
- parallel_coordinates
|
||||
- radviz
|
||||
- scatter_matrix
|
||||
|
||||
Use the code in pandas/plotting/_matplotib.py and
|
||||
https://github.com/pyviz/hvplot as a reference on how to write a backend.
|
||||
|
||||
For the discussion about the API see
|
||||
https://github.com/pandas-dev/pandas/issues/26747.
|
||||
"""
|
||||
from pandas.plotting._core import (
|
||||
PlotAccessor,
|
||||
boxplot,
|
||||
boxplot_frame,
|
||||
boxplot_frame_groupby,
|
||||
hist_frame,
|
||||
hist_series,
|
||||
)
|
||||
from pandas.plotting._misc import (
|
||||
andrews_curves,
|
||||
autocorrelation_plot,
|
||||
bootstrap_plot,
|
||||
deregister as deregister_matplotlib_converters,
|
||||
lag_plot,
|
||||
parallel_coordinates,
|
||||
plot_params,
|
||||
radviz,
|
||||
register as register_matplotlib_converters,
|
||||
scatter_matrix,
|
||||
table,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PlotAccessor",
|
||||
"boxplot",
|
||||
"boxplot_frame",
|
||||
"boxplot_frame_groupby",
|
||||
"hist_frame",
|
||||
"hist_series",
|
||||
"scatter_matrix",
|
||||
"radviz",
|
||||
"andrews_curves",
|
||||
"bootstrap_plot",
|
||||
"parallel_coordinates",
|
||||
"lag_plot",
|
||||
"autocorrelation_plot",
|
||||
"table",
|
||||
"plot_params",
|
||||
"register_matplotlib_converters",
|
||||
"deregister_matplotlib_converters",
|
||||
]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
1946
lib/python3.13/site-packages/pandas/plotting/_core.py
Normal file
1946
lib/python3.13/site-packages/pandas/plotting/_core.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pandas.plotting._matplotlib.boxplot import (
|
||||
BoxPlot,
|
||||
boxplot,
|
||||
boxplot_frame,
|
||||
boxplot_frame_groupby,
|
||||
)
|
||||
from pandas.plotting._matplotlib.converter import (
|
||||
deregister,
|
||||
register,
|
||||
)
|
||||
from pandas.plotting._matplotlib.core import (
|
||||
AreaPlot,
|
||||
BarhPlot,
|
||||
BarPlot,
|
||||
HexBinPlot,
|
||||
LinePlot,
|
||||
PiePlot,
|
||||
ScatterPlot,
|
||||
)
|
||||
from pandas.plotting._matplotlib.hist import (
|
||||
HistPlot,
|
||||
KdePlot,
|
||||
hist_frame,
|
||||
hist_series,
|
||||
)
|
||||
from pandas.plotting._matplotlib.misc import (
|
||||
andrews_curves,
|
||||
autocorrelation_plot,
|
||||
bootstrap_plot,
|
||||
lag_plot,
|
||||
parallel_coordinates,
|
||||
radviz,
|
||||
scatter_matrix,
|
||||
)
|
||||
from pandas.plotting._matplotlib.tools import table
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pandas.plotting._matplotlib.core import MPLPlot
|
||||
|
||||
PLOT_CLASSES: dict[str, type[MPLPlot]] = {
|
||||
"line": LinePlot,
|
||||
"bar": BarPlot,
|
||||
"barh": BarhPlot,
|
||||
"box": BoxPlot,
|
||||
"hist": HistPlot,
|
||||
"kde": KdePlot,
|
||||
"area": AreaPlot,
|
||||
"pie": PiePlot,
|
||||
"scatter": ScatterPlot,
|
||||
"hexbin": HexBinPlot,
|
||||
}
|
||||
|
||||
|
||||
def plot(data, kind, **kwargs):
|
||||
# Importing pyplot at the top of the file (before the converters are
|
||||
# registered) causes problems in matplotlib 2 (converters seem to not
|
||||
# work)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if kwargs.pop("reuse_plot", False):
|
||||
ax = kwargs.get("ax")
|
||||
if ax is None and len(plt.get_fignums()) > 0:
|
||||
with plt.rc_context():
|
||||
ax = plt.gca()
|
||||
kwargs["ax"] = getattr(ax, "left_ax", ax)
|
||||
plot_obj = PLOT_CLASSES[kind](data, **kwargs)
|
||||
plot_obj.generate()
|
||||
plot_obj.draw()
|
||||
return plot_obj.result
|
||||
|
||||
|
||||
__all__ = [
|
||||
"plot",
|
||||
"hist_series",
|
||||
"hist_frame",
|
||||
"boxplot",
|
||||
"boxplot_frame",
|
||||
"boxplot_frame_groupby",
|
||||
"table",
|
||||
"andrews_curves",
|
||||
"autocorrelation_plot",
|
||||
"bootstrap_plot",
|
||||
"lag_plot",
|
||||
"parallel_coordinates",
|
||||
"radviz",
|
||||
"scatter_matrix",
|
||||
"register",
|
||||
"deregister",
|
||||
]
|
@ -0,0 +1,572 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
)
|
||||
import warnings
|
||||
|
||||
from matplotlib.artist import setp
|
||||
import numpy as np
|
||||
|
||||
from pandas._libs import lib
|
||||
from pandas.util._decorators import cache_readonly
|
||||
from pandas.util._exceptions import find_stack_level
|
||||
|
||||
from pandas.core.dtypes.common import is_dict_like
|
||||
from pandas.core.dtypes.generic import ABCSeries
|
||||
from pandas.core.dtypes.missing import remove_na_arraylike
|
||||
|
||||
import pandas as pd
|
||||
import pandas.core.common as com
|
||||
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.plotting._matplotlib.core import (
|
||||
LinePlot,
|
||||
MPLPlot,
|
||||
)
|
||||
from pandas.plotting._matplotlib.groupby import create_iter_data_given_by
|
||||
from pandas.plotting._matplotlib.style import get_standard_colors
|
||||
from pandas.plotting._matplotlib.tools import (
|
||||
create_subplots,
|
||||
flatten_axes,
|
||||
maybe_adjust_figure,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Collection
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
from pandas._typing import MatplotlibColor
|
||||
|
||||
|
||||
def _set_ticklabels(ax: Axes, labels: list[str], is_vertical: bool, **kwargs) -> None:
|
||||
"""Set the tick labels of a given axis.
|
||||
|
||||
Due to https://github.com/matplotlib/matplotlib/pull/17266, we need to handle the
|
||||
case of repeated ticks (due to `FixedLocator`) and thus we duplicate the number of
|
||||
labels.
|
||||
"""
|
||||
ticks = ax.get_xticks() if is_vertical else ax.get_yticks()
|
||||
if len(ticks) != len(labels):
|
||||
i, remainder = divmod(len(ticks), len(labels))
|
||||
assert remainder == 0, remainder
|
||||
labels *= i
|
||||
if is_vertical:
|
||||
ax.set_xticklabels(labels, **kwargs)
|
||||
else:
|
||||
ax.set_yticklabels(labels, **kwargs)
|
||||
|
||||
|
||||
class BoxPlot(LinePlot):
|
||||
@property
|
||||
def _kind(self) -> Literal["box"]:
|
||||
return "box"
|
||||
|
||||
_layout_type = "horizontal"
|
||||
|
||||
_valid_return_types = (None, "axes", "dict", "both")
|
||||
|
||||
class BP(NamedTuple):
|
||||
# namedtuple to hold results
|
||||
ax: Axes
|
||||
lines: dict[str, list[Line2D]]
|
||||
|
||||
def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
|
||||
if return_type not in self._valid_return_types:
|
||||
raise ValueError("return_type must be {None, 'axes', 'dict', 'both'}")
|
||||
|
||||
self.return_type = return_type
|
||||
# Do not call LinePlot.__init__ which may fill nan
|
||||
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
|
||||
|
||||
if self.subplots:
|
||||
# Disable label ax sharing. Otherwise, all subplots shows last
|
||||
# column label
|
||||
if self.orientation == "vertical":
|
||||
self.sharex = False
|
||||
else:
|
||||
self.sharey = False
|
||||
|
||||
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
|
||||
@classmethod
|
||||
def _plot( # type: ignore[override]
|
||||
cls, ax: Axes, y: np.ndarray, column_num=None, return_type: str = "axes", **kwds
|
||||
):
|
||||
ys: np.ndarray | list[np.ndarray]
|
||||
if y.ndim == 2:
|
||||
ys = [remove_na_arraylike(v) for v in y]
|
||||
# Boxplot fails with empty arrays, so need to add a NaN
|
||||
# if any cols are empty
|
||||
# GH 8181
|
||||
ys = [v if v.size > 0 else np.array([np.nan]) for v in ys]
|
||||
else:
|
||||
ys = remove_na_arraylike(y)
|
||||
bp = ax.boxplot(ys, **kwds)
|
||||
|
||||
if return_type == "dict":
|
||||
return bp, bp
|
||||
elif return_type == "both":
|
||||
return cls.BP(ax=ax, lines=bp), bp
|
||||
else:
|
||||
return ax, bp
|
||||
|
||||
def _validate_color_args(self, color, colormap):
|
||||
if color is lib.no_default:
|
||||
return None
|
||||
|
||||
if colormap is not None:
|
||||
warnings.warn(
|
||||
"'color' and 'colormap' cannot be used "
|
||||
"simultaneously. Using 'color'",
|
||||
stacklevel=find_stack_level(),
|
||||
)
|
||||
|
||||
if isinstance(color, dict):
|
||||
valid_keys = ["boxes", "whiskers", "medians", "caps"]
|
||||
for key in color:
|
||||
if key not in valid_keys:
|
||||
raise ValueError(
|
||||
f"color dict contains invalid key '{key}'. "
|
||||
f"The key must be either {valid_keys}"
|
||||
)
|
||||
return color
|
||||
|
||||
@cache_readonly
|
||||
def _color_attrs(self):
|
||||
# get standard colors for default
|
||||
# use 2 colors by default, for box/whisker and median
|
||||
# flier colors isn't needed here
|
||||
# because it can be specified by ``sym`` kw
|
||||
return get_standard_colors(num_colors=3, colormap=self.colormap, color=None)
|
||||
|
||||
@cache_readonly
|
||||
def _boxes_c(self):
|
||||
return self._color_attrs[0]
|
||||
|
||||
@cache_readonly
|
||||
def _whiskers_c(self):
|
||||
return self._color_attrs[0]
|
||||
|
||||
@cache_readonly
|
||||
def _medians_c(self):
|
||||
return self._color_attrs[2]
|
||||
|
||||
@cache_readonly
|
||||
def _caps_c(self):
|
||||
return self._color_attrs[0]
|
||||
|
||||
def _get_colors(
|
||||
self,
|
||||
num_colors=None,
|
||||
color_kwds: dict[str, MatplotlibColor]
|
||||
| MatplotlibColor
|
||||
| Collection[MatplotlibColor]
|
||||
| None = "color",
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def maybe_color_bp(self, bp) -> None:
|
||||
if isinstance(self.color, dict):
|
||||
boxes = self.color.get("boxes", self._boxes_c)
|
||||
whiskers = self.color.get("whiskers", self._whiskers_c)
|
||||
medians = self.color.get("medians", self._medians_c)
|
||||
caps = self.color.get("caps", self._caps_c)
|
||||
else:
|
||||
# Other types are forwarded to matplotlib
|
||||
# If None, use default colors
|
||||
boxes = self.color or self._boxes_c
|
||||
whiskers = self.color or self._whiskers_c
|
||||
medians = self.color or self._medians_c
|
||||
caps = self.color or self._caps_c
|
||||
|
||||
color_tup = (boxes, whiskers, medians, caps)
|
||||
maybe_color_bp(bp, color_tup=color_tup, **self.kwds)
|
||||
|
||||
def _make_plot(self, fig: Figure) -> None:
|
||||
if self.subplots:
|
||||
self._return_obj = pd.Series(dtype=object)
|
||||
|
||||
# Re-create iterated data if `by` is assigned by users
|
||||
data = (
|
||||
create_iter_data_given_by(self.data, self._kind)
|
||||
if self.by is not None
|
||||
else self.data
|
||||
)
|
||||
|
||||
# error: Argument "data" to "_iter_data" of "MPLPlot" has
|
||||
# incompatible type "object"; expected "DataFrame |
|
||||
# dict[Hashable, Series | DataFrame]"
|
||||
for i, (label, y) in enumerate(self._iter_data(data=data)): # type: ignore[arg-type]
|
||||
ax = self._get_ax(i)
|
||||
kwds = self.kwds.copy()
|
||||
|
||||
# When by is applied, show title for subplots to know which group it is
|
||||
# just like df.boxplot, and need to apply T on y to provide right input
|
||||
if self.by is not None:
|
||||
y = y.T
|
||||
ax.set_title(pprint_thing(label))
|
||||
|
||||
# When `by` is assigned, the ticklabels will become unique grouped
|
||||
# values, instead of label which is used as subtitle in this case.
|
||||
# error: "Index" has no attribute "levels"; maybe "nlevels"?
|
||||
levels = self.data.columns.levels # type: ignore[attr-defined]
|
||||
ticklabels = [pprint_thing(col) for col in levels[0]]
|
||||
else:
|
||||
ticklabels = [pprint_thing(label)]
|
||||
|
||||
ret, bp = self._plot(
|
||||
ax, y, column_num=i, return_type=self.return_type, **kwds
|
||||
)
|
||||
self.maybe_color_bp(bp)
|
||||
self._return_obj[label] = ret
|
||||
_set_ticklabels(
|
||||
ax=ax, labels=ticklabels, is_vertical=self.orientation == "vertical"
|
||||
)
|
||||
else:
|
||||
y = self.data.values.T
|
||||
ax = self._get_ax(0)
|
||||
kwds = self.kwds.copy()
|
||||
|
||||
ret, bp = self._plot(
|
||||
ax, y, column_num=0, return_type=self.return_type, **kwds
|
||||
)
|
||||
self.maybe_color_bp(bp)
|
||||
self._return_obj = ret
|
||||
|
||||
labels = [pprint_thing(left) for left in self.data.columns]
|
||||
if not self.use_index:
|
||||
labels = [pprint_thing(key) for key in range(len(labels))]
|
||||
_set_ticklabels(
|
||||
ax=ax, labels=labels, is_vertical=self.orientation == "vertical"
|
||||
)
|
||||
|
||||
def _make_legend(self) -> None:
|
||||
pass
|
||||
|
||||
def _post_plot_logic(self, ax: Axes, data) -> None:
|
||||
# GH 45465: make sure that the boxplot doesn't ignore xlabel/ylabel
|
||||
if self.xlabel:
|
||||
ax.set_xlabel(pprint_thing(self.xlabel))
|
||||
if self.ylabel:
|
||||
ax.set_ylabel(pprint_thing(self.ylabel))
|
||||
|
||||
@property
|
||||
def orientation(self) -> Literal["horizontal", "vertical"]:
|
||||
if self.kwds.get("vert", True):
|
||||
return "vertical"
|
||||
else:
|
||||
return "horizontal"
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
if self.return_type is None:
|
||||
return super().result
|
||||
else:
|
||||
return self._return_obj
|
||||
|
||||
|
||||
def maybe_color_bp(bp, color_tup, **kwds) -> None:
|
||||
# GH#30346, when users specifying those arguments explicitly, our defaults
|
||||
# for these four kwargs should be overridden; if not, use Pandas settings
|
||||
if not kwds.get("boxprops"):
|
||||
setp(bp["boxes"], color=color_tup[0], alpha=1)
|
||||
if not kwds.get("whiskerprops"):
|
||||
setp(bp["whiskers"], color=color_tup[1], alpha=1)
|
||||
if not kwds.get("medianprops"):
|
||||
setp(bp["medians"], color=color_tup[2], alpha=1)
|
||||
if not kwds.get("capprops"):
|
||||
setp(bp["caps"], color=color_tup[3], alpha=1)
|
||||
|
||||
|
||||
def _grouped_plot_by_column(
|
||||
plotf,
|
||||
data,
|
||||
columns=None,
|
||||
by=None,
|
||||
numeric_only: bool = True,
|
||||
grid: bool = False,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
ax=None,
|
||||
layout=None,
|
||||
return_type=None,
|
||||
**kwargs,
|
||||
):
|
||||
grouped = data.groupby(by, observed=False)
|
||||
if columns is None:
|
||||
if not isinstance(by, (list, tuple)):
|
||||
by = [by]
|
||||
columns = data._get_numeric_data().columns.difference(by)
|
||||
naxes = len(columns)
|
||||
fig, axes = create_subplots(
|
||||
naxes=naxes,
|
||||
sharex=kwargs.pop("sharex", True),
|
||||
sharey=kwargs.pop("sharey", True),
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
layout=layout,
|
||||
)
|
||||
|
||||
_axes = flatten_axes(axes)
|
||||
|
||||
# GH 45465: move the "by" label based on "vert"
|
||||
xlabel, ylabel = kwargs.pop("xlabel", None), kwargs.pop("ylabel", None)
|
||||
if kwargs.get("vert", True):
|
||||
xlabel = xlabel or by
|
||||
else:
|
||||
ylabel = ylabel or by
|
||||
|
||||
ax_values = []
|
||||
|
||||
for i, col in enumerate(columns):
|
||||
ax = _axes[i]
|
||||
gp_col = grouped[col]
|
||||
keys, values = zip(*gp_col)
|
||||
re_plotf = plotf(keys, values, ax, xlabel=xlabel, ylabel=ylabel, **kwargs)
|
||||
ax.set_title(col)
|
||||
ax_values.append(re_plotf)
|
||||
ax.grid(grid)
|
||||
|
||||
result = pd.Series(ax_values, index=columns, copy=False)
|
||||
|
||||
# Return axes in multiplot case, maybe revisit later # 985
|
||||
if return_type is None:
|
||||
result = axes
|
||||
|
||||
byline = by[0] if len(by) == 1 else by
|
||||
fig.suptitle(f"Boxplot grouped by {byline}")
|
||||
maybe_adjust_figure(fig, bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def boxplot(
|
||||
data,
|
||||
column=None,
|
||||
by=None,
|
||||
ax=None,
|
||||
fontsize: int | None = None,
|
||||
rot: int = 0,
|
||||
grid: bool = True,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
layout=None,
|
||||
return_type=None,
|
||||
**kwds,
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# validate return_type:
|
||||
if return_type not in BoxPlot._valid_return_types:
|
||||
raise ValueError("return_type must be {'axes', 'dict', 'both'}")
|
||||
|
||||
if isinstance(data, ABCSeries):
|
||||
data = data.to_frame("x")
|
||||
column = "x"
|
||||
|
||||
def _get_colors():
|
||||
# num_colors=3 is required as method maybe_color_bp takes the colors
|
||||
# in positions 0 and 2.
|
||||
# if colors not provided, use same defaults as DataFrame.plot.box
|
||||
result = get_standard_colors(num_colors=3)
|
||||
result = np.take(result, [0, 0, 2])
|
||||
result = np.append(result, "k")
|
||||
|
||||
colors = kwds.pop("color", None)
|
||||
if colors:
|
||||
if is_dict_like(colors):
|
||||
# replace colors in result array with user-specified colors
|
||||
# taken from the colors dict parameter
|
||||
# "boxes" value placed in position 0, "whiskers" in 1, etc.
|
||||
valid_keys = ["boxes", "whiskers", "medians", "caps"]
|
||||
key_to_index = dict(zip(valid_keys, range(4)))
|
||||
for key, value in colors.items():
|
||||
if key in valid_keys:
|
||||
result[key_to_index[key]] = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"color dict contains invalid key '{key}'. "
|
||||
f"The key must be either {valid_keys}"
|
||||
)
|
||||
else:
|
||||
result.fill(colors)
|
||||
|
||||
return result
|
||||
|
||||
def plot_group(keys, values, ax: Axes, **kwds):
|
||||
# GH 45465: xlabel/ylabel need to be popped out before plotting happens
|
||||
xlabel, ylabel = kwds.pop("xlabel", None), kwds.pop("ylabel", None)
|
||||
if xlabel:
|
||||
ax.set_xlabel(pprint_thing(xlabel))
|
||||
if ylabel:
|
||||
ax.set_ylabel(pprint_thing(ylabel))
|
||||
|
||||
keys = [pprint_thing(x) for x in keys]
|
||||
values = [np.asarray(remove_na_arraylike(v), dtype=object) for v in values]
|
||||
bp = ax.boxplot(values, **kwds)
|
||||
if fontsize is not None:
|
||||
ax.tick_params(axis="both", labelsize=fontsize)
|
||||
|
||||
# GH 45465: x/y are flipped when "vert" changes
|
||||
_set_ticklabels(
|
||||
ax=ax, labels=keys, is_vertical=kwds.get("vert", True), rotation=rot
|
||||
)
|
||||
maybe_color_bp(bp, color_tup=colors, **kwds)
|
||||
|
||||
# Return axes in multiplot case, maybe revisit later # 985
|
||||
if return_type == "dict":
|
||||
return bp
|
||||
elif return_type == "both":
|
||||
return BoxPlot.BP(ax=ax, lines=bp)
|
||||
else:
|
||||
return ax
|
||||
|
||||
colors = _get_colors()
|
||||
if column is None:
|
||||
columns = None
|
||||
elif isinstance(column, (list, tuple)):
|
||||
columns = column
|
||||
else:
|
||||
columns = [column]
|
||||
|
||||
if by is not None:
|
||||
# Prefer array return type for 2-D plots to match the subplot layout
|
||||
# https://github.com/pandas-dev/pandas/pull/12216#issuecomment-241175580
|
||||
result = _grouped_plot_by_column(
|
||||
plot_group,
|
||||
data,
|
||||
columns=columns,
|
||||
by=by,
|
||||
grid=grid,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
layout=layout,
|
||||
return_type=return_type,
|
||||
**kwds,
|
||||
)
|
||||
else:
|
||||
if return_type is None:
|
||||
return_type = "axes"
|
||||
if layout is not None:
|
||||
raise ValueError("The 'layout' keyword is not supported when 'by' is None")
|
||||
|
||||
if ax is None:
|
||||
rc = {"figure.figsize": figsize} if figsize is not None else {}
|
||||
with plt.rc_context(rc):
|
||||
ax = plt.gca()
|
||||
data = data._get_numeric_data()
|
||||
naxes = len(data.columns)
|
||||
if naxes == 0:
|
||||
raise ValueError(
|
||||
"boxplot method requires numerical columns, nothing to plot."
|
||||
)
|
||||
if columns is None:
|
||||
columns = data.columns
|
||||
else:
|
||||
data = data[columns]
|
||||
|
||||
result = plot_group(columns, data.values.T, ax, **kwds)
|
||||
ax.grid(grid)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def boxplot_frame(
|
||||
self,
|
||||
column=None,
|
||||
by=None,
|
||||
ax=None,
|
||||
fontsize: int | None = None,
|
||||
rot: int = 0,
|
||||
grid: bool = True,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
layout=None,
|
||||
return_type=None,
|
||||
**kwds,
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
ax = boxplot(
|
||||
self,
|
||||
column=column,
|
||||
by=by,
|
||||
ax=ax,
|
||||
fontsize=fontsize,
|
||||
grid=grid,
|
||||
rot=rot,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
return_type=return_type,
|
||||
**kwds,
|
||||
)
|
||||
plt.draw_if_interactive()
|
||||
return ax
|
||||
|
||||
|
||||
def boxplot_frame_groupby(
|
||||
grouped,
|
||||
subplots: bool = True,
|
||||
column=None,
|
||||
fontsize: int | None = None,
|
||||
rot: int = 0,
|
||||
grid: bool = True,
|
||||
ax=None,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
layout=None,
|
||||
sharex: bool = False,
|
||||
sharey: bool = True,
|
||||
**kwds,
|
||||
):
|
||||
if subplots is True:
|
||||
naxes = len(grouped)
|
||||
fig, axes = create_subplots(
|
||||
naxes=naxes,
|
||||
squeeze=False,
|
||||
ax=ax,
|
||||
sharex=sharex,
|
||||
sharey=sharey,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
)
|
||||
axes = flatten_axes(axes)
|
||||
|
||||
ret = pd.Series(dtype=object)
|
||||
|
||||
for (key, group), ax in zip(grouped, axes):
|
||||
d = group.boxplot(
|
||||
ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds
|
||||
)
|
||||
ax.set_title(pprint_thing(key))
|
||||
ret.loc[key] = d
|
||||
maybe_adjust_figure(fig, bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
|
||||
else:
|
||||
keys, frames = zip(*grouped)
|
||||
if grouped.axis == 0:
|
||||
df = pd.concat(frames, keys=keys, axis=1)
|
||||
elif len(frames) > 1:
|
||||
df = frames[0].join(frames[1::])
|
||||
else:
|
||||
df = frames[0]
|
||||
|
||||
# GH 16748, DataFrameGroupby fails when subplots=False and `column` argument
|
||||
# is assigned, and in this case, since `df` here becomes MI after groupby,
|
||||
# so we need to couple the keys (grouped values) and column (original df
|
||||
# column) together to search for subset to plot
|
||||
if column is not None:
|
||||
column = com.convert_to_list_like(column)
|
||||
multi_key = pd.MultiIndex.from_product([keys, column])
|
||||
column = list(multi_key.values)
|
||||
ret = df.boxplot(
|
||||
column=column,
|
||||
fontsize=fontsize,
|
||||
rot=rot,
|
||||
grid=grid,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
**kwds,
|
||||
)
|
||||
return ret
|
File diff suppressed because it is too large
Load Diff
2125
lib/python3.13/site-packages/pandas/plotting/_matplotlib/core.py
Normal file
2125
lib/python3.13/site-packages/pandas/plotting/_matplotlib/core.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pandas.core.dtypes.missing import remove_na_arraylike
|
||||
|
||||
from pandas import (
|
||||
MultiIndex,
|
||||
concat,
|
||||
)
|
||||
|
||||
from pandas.plotting._matplotlib.misc import unpack_single_str_list
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Hashable
|
||||
|
||||
from pandas._typing import IndexLabel
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
Series,
|
||||
)
|
||||
|
||||
|
||||
def create_iter_data_given_by(
|
||||
data: DataFrame, kind: str = "hist"
|
||||
) -> dict[Hashable, DataFrame | Series]:
|
||||
"""
|
||||
Create data for iteration given `by` is assigned or not, and it is only
|
||||
used in both hist and boxplot.
|
||||
|
||||
If `by` is assigned, return a dictionary of DataFrames in which the key of
|
||||
dictionary is the values in groups.
|
||||
If `by` is not assigned, return input as is, and this preserves current
|
||||
status of iter_data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : reformatted grouped data from `_compute_plot_data` method.
|
||||
kind : str, plot kind. This function is only used for `hist` and `box` plots.
|
||||
|
||||
Returns
|
||||
-------
|
||||
iter_data : DataFrame or Dictionary of DataFrames
|
||||
|
||||
Examples
|
||||
--------
|
||||
If `by` is assigned:
|
||||
|
||||
>>> import numpy as np
|
||||
>>> tuples = [('h1', 'a'), ('h1', 'b'), ('h2', 'a'), ('h2', 'b')]
|
||||
>>> mi = pd.MultiIndex.from_tuples(tuples)
|
||||
>>> value = [[1, 3, np.nan, np.nan],
|
||||
... [3, 4, np.nan, np.nan], [np.nan, np.nan, 5, 6]]
|
||||
>>> data = pd.DataFrame(value, columns=mi)
|
||||
>>> create_iter_data_given_by(data)
|
||||
{'h1': h1
|
||||
a b
|
||||
0 1.0 3.0
|
||||
1 3.0 4.0
|
||||
2 NaN NaN, 'h2': h2
|
||||
a b
|
||||
0 NaN NaN
|
||||
1 NaN NaN
|
||||
2 5.0 6.0}
|
||||
"""
|
||||
|
||||
# For `hist` plot, before transformation, the values in level 0 are values
|
||||
# in groups and subplot titles, and later used for column subselection and
|
||||
# iteration; For `box` plot, values in level 1 are column names to show,
|
||||
# and are used for iteration and as subplots titles.
|
||||
if kind == "hist":
|
||||
level = 0
|
||||
else:
|
||||
level = 1
|
||||
|
||||
# Select sub-columns based on the value of level of MI, and if `by` is
|
||||
# assigned, data must be a MI DataFrame
|
||||
assert isinstance(data.columns, MultiIndex)
|
||||
return {
|
||||
col: data.loc[:, data.columns.get_level_values(level) == col]
|
||||
for col in data.columns.levels[level]
|
||||
}
|
||||
|
||||
|
||||
def reconstruct_data_with_by(
|
||||
data: DataFrame, by: IndexLabel, cols: IndexLabel
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Internal function to group data, and reassign multiindex column names onto the
|
||||
result in order to let grouped data be used in _compute_plot_data method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : Original DataFrame to plot
|
||||
by : grouped `by` parameter selected by users
|
||||
cols : columns of data set (excluding columns used in `by`)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Output is the reconstructed DataFrame with MultiIndex columns. The first level
|
||||
of MI is unique values of groups, and second level of MI is the columns
|
||||
selected by users.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> d = {'h': ['h1', 'h1', 'h2'], 'a': [1, 3, 5], 'b': [3, 4, 6]}
|
||||
>>> df = pd.DataFrame(d)
|
||||
>>> reconstruct_data_with_by(df, by='h', cols=['a', 'b'])
|
||||
h1 h2
|
||||
a b a b
|
||||
0 1.0 3.0 NaN NaN
|
||||
1 3.0 4.0 NaN NaN
|
||||
2 NaN NaN 5.0 6.0
|
||||
"""
|
||||
by_modified = unpack_single_str_list(by)
|
||||
grouped = data.groupby(by_modified)
|
||||
|
||||
data_list = []
|
||||
for key, group in grouped:
|
||||
# error: List item 1 has incompatible type "Union[Hashable,
|
||||
# Sequence[Hashable]]"; expected "Iterable[Hashable]"
|
||||
columns = MultiIndex.from_product([[key], cols]) # type: ignore[list-item]
|
||||
sub_group = group[cols]
|
||||
sub_group.columns = columns
|
||||
data_list.append(sub_group)
|
||||
|
||||
data = concat(data_list, axis=1)
|
||||
return data
|
||||
|
||||
|
||||
def reformat_hist_y_given_by(y: np.ndarray, by: IndexLabel | None) -> np.ndarray:
|
||||
"""Internal function to reformat y given `by` is applied or not for hist plot.
|
||||
|
||||
If by is None, input y is 1-d with NaN removed; and if by is not None, groupby
|
||||
will take place and input y is multi-dimensional array.
|
||||
"""
|
||||
if by is not None and len(y.shape) > 1:
|
||||
return np.array([remove_na_arraylike(col) for col in y.T]).T
|
||||
return remove_na_arraylike(y)
|
581
lib/python3.13/site-packages/pandas/plotting/_matplotlib/hist.py
Normal file
581
lib/python3.13/site-packages/pandas/plotting/_matplotlib/hist.py
Normal file
@ -0,0 +1,581 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
final,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pandas.core.dtypes.common import (
|
||||
is_integer,
|
||||
is_list_like,
|
||||
)
|
||||
from pandas.core.dtypes.generic import (
|
||||
ABCDataFrame,
|
||||
ABCIndex,
|
||||
)
|
||||
from pandas.core.dtypes.missing import (
|
||||
isna,
|
||||
remove_na_arraylike,
|
||||
)
|
||||
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.plotting._matplotlib.core import (
|
||||
LinePlot,
|
||||
MPLPlot,
|
||||
)
|
||||
from pandas.plotting._matplotlib.groupby import (
|
||||
create_iter_data_given_by,
|
||||
reformat_hist_y_given_by,
|
||||
)
|
||||
from pandas.plotting._matplotlib.misc import unpack_single_str_list
|
||||
from pandas.plotting._matplotlib.tools import (
|
||||
create_subplots,
|
||||
flatten_axes,
|
||||
maybe_adjust_figure,
|
||||
set_ticks_props,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pandas._typing import PlottingOrientation
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
Series,
|
||||
)
|
||||
|
||||
|
||||
class HistPlot(LinePlot):
|
||||
@property
|
||||
def _kind(self) -> Literal["hist", "kde"]:
|
||||
return "hist"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data,
|
||||
bins: int | np.ndarray | list[np.ndarray] = 10,
|
||||
bottom: int | np.ndarray = 0,
|
||||
*,
|
||||
range=None,
|
||||
weights=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if is_list_like(bottom):
|
||||
bottom = np.array(bottom)
|
||||
self.bottom = bottom
|
||||
|
||||
self._bin_range = range
|
||||
self.weights = weights
|
||||
|
||||
self.xlabel = kwargs.get("xlabel")
|
||||
self.ylabel = kwargs.get("ylabel")
|
||||
# Do not call LinePlot.__init__ which may fill nan
|
||||
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
|
||||
|
||||
self.bins = self._adjust_bins(bins)
|
||||
|
||||
def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
|
||||
if is_integer(bins):
|
||||
if self.by is not None:
|
||||
by_modified = unpack_single_str_list(self.by)
|
||||
grouped = self.data.groupby(by_modified)[self.columns]
|
||||
bins = [self._calculate_bins(group, bins) for key, group in grouped]
|
||||
else:
|
||||
bins = self._calculate_bins(self.data, bins)
|
||||
return bins
|
||||
|
||||
def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
|
||||
"""Calculate bins given data"""
|
||||
nd_values = data.infer_objects(copy=False)._get_numeric_data()
|
||||
values = np.ravel(nd_values)
|
||||
values = values[~isna(values)]
|
||||
|
||||
hist, bins = np.histogram(values, bins=bins, range=self._bin_range)
|
||||
return bins
|
||||
|
||||
# error: Signature of "_plot" incompatible with supertype "LinePlot"
|
||||
@classmethod
|
||||
def _plot( # type: ignore[override]
|
||||
cls,
|
||||
ax: Axes,
|
||||
y: np.ndarray,
|
||||
style=None,
|
||||
bottom: int | np.ndarray = 0,
|
||||
column_num: int = 0,
|
||||
stacking_id=None,
|
||||
*,
|
||||
bins,
|
||||
**kwds,
|
||||
):
|
||||
if column_num == 0:
|
||||
cls._initialize_stacker(ax, stacking_id, len(bins) - 1)
|
||||
|
||||
base = np.zeros(len(bins) - 1)
|
||||
bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])
|
||||
# ignore style
|
||||
n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
|
||||
cls._update_stacker(ax, stacking_id, n)
|
||||
return patches
|
||||
|
||||
def _make_plot(self, fig: Figure) -> None:
|
||||
colors = self._get_colors()
|
||||
stacking_id = self._get_stacking_id()
|
||||
|
||||
# Re-create iterated data if `by` is assigned by users
|
||||
data = (
|
||||
create_iter_data_given_by(self.data, self._kind)
|
||||
if self.by is not None
|
||||
else self.data
|
||||
)
|
||||
|
||||
# error: Argument "data" to "_iter_data" of "MPLPlot" has incompatible
|
||||
# type "object"; expected "DataFrame | dict[Hashable, Series | DataFrame]"
|
||||
for i, (label, y) in enumerate(self._iter_data(data=data)): # type: ignore[arg-type]
|
||||
ax = self._get_ax(i)
|
||||
|
||||
kwds = self.kwds.copy()
|
||||
if self.color is not None:
|
||||
kwds["color"] = self.color
|
||||
|
||||
label = pprint_thing(label)
|
||||
label = self._mark_right_label(label, index=i)
|
||||
kwds["label"] = label
|
||||
|
||||
style, kwds = self._apply_style_colors(colors, kwds, i, label)
|
||||
if style is not None:
|
||||
kwds["style"] = style
|
||||
|
||||
self._make_plot_keywords(kwds, y)
|
||||
|
||||
# the bins is multi-dimension array now and each plot need only 1-d and
|
||||
# when by is applied, label should be columns that are grouped
|
||||
if self.by is not None:
|
||||
kwds["bins"] = kwds["bins"][i]
|
||||
kwds["label"] = self.columns
|
||||
kwds.pop("color")
|
||||
|
||||
if self.weights is not None:
|
||||
kwds["weights"] = type(self)._get_column_weights(self.weights, i, y)
|
||||
|
||||
y = reformat_hist_y_given_by(y, self.by)
|
||||
|
||||
artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)
|
||||
|
||||
# when by is applied, show title for subplots to know which group it is
|
||||
if self.by is not None:
|
||||
ax.set_title(pprint_thing(label))
|
||||
|
||||
self._append_legend_handles_labels(artists[0], label)
|
||||
|
||||
def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
|
||||
"""merge BoxPlot/KdePlot properties to passed kwds"""
|
||||
# y is required for KdePlot
|
||||
kwds["bottom"] = self.bottom
|
||||
kwds["bins"] = self.bins
|
||||
|
||||
@final
|
||||
@staticmethod
|
||||
def _get_column_weights(weights, i: int, y):
|
||||
# We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
|
||||
# and each sub-array (10,) will be called in each iteration. If users only
|
||||
# provide 1D array, we assume the same weights is used for all iterations
|
||||
if weights is not None:
|
||||
if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
|
||||
try:
|
||||
weights = weights[:, i]
|
||||
except IndexError as err:
|
||||
raise ValueError(
|
||||
"weights must have the same shape as data, "
|
||||
"or be a single column"
|
||||
) from err
|
||||
weights = weights[~isna(y)]
|
||||
return weights
|
||||
|
||||
def _post_plot_logic(self, ax: Axes, data) -> None:
|
||||
if self.orientation == "horizontal":
|
||||
# error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible
|
||||
# type "Hashable"; expected "str"
|
||||
ax.set_xlabel(
|
||||
"Frequency"
|
||||
if self.xlabel is None
|
||||
else self.xlabel # type: ignore[arg-type]
|
||||
)
|
||||
ax.set_ylabel(self.ylabel) # type: ignore[arg-type]
|
||||
else:
|
||||
ax.set_xlabel(self.xlabel) # type: ignore[arg-type]
|
||||
ax.set_ylabel(
|
||||
"Frequency"
|
||||
if self.ylabel is None
|
||||
else self.ylabel # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@property
|
||||
def orientation(self) -> PlottingOrientation:
|
||||
if self.kwds.get("orientation", None) == "horizontal":
|
||||
return "horizontal"
|
||||
else:
|
||||
return "vertical"
|
||||
|
||||
|
||||
class KdePlot(HistPlot):
|
||||
@property
|
||||
def _kind(self) -> Literal["kde"]:
|
||||
return "kde"
|
||||
|
||||
@property
|
||||
def orientation(self) -> Literal["vertical"]:
|
||||
return "vertical"
|
||||
|
||||
def __init__(
|
||||
self, data, bw_method=None, ind=None, *, weights=None, **kwargs
|
||||
) -> None:
|
||||
# Do not call LinePlot.__init__ which may fill nan
|
||||
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
|
||||
self.bw_method = bw_method
|
||||
self.ind = ind
|
||||
self.weights = weights
|
||||
|
||||
@staticmethod
|
||||
def _get_ind(y: np.ndarray, ind):
|
||||
if ind is None:
|
||||
# np.nanmax() and np.nanmin() ignores the missing values
|
||||
sample_range = np.nanmax(y) - np.nanmin(y)
|
||||
ind = np.linspace(
|
||||
np.nanmin(y) - 0.5 * sample_range,
|
||||
np.nanmax(y) + 0.5 * sample_range,
|
||||
1000,
|
||||
)
|
||||
elif is_integer(ind):
|
||||
sample_range = np.nanmax(y) - np.nanmin(y)
|
||||
ind = np.linspace(
|
||||
np.nanmin(y) - 0.5 * sample_range,
|
||||
np.nanmax(y) + 0.5 * sample_range,
|
||||
ind,
|
||||
)
|
||||
return ind
|
||||
|
||||
@classmethod
|
||||
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
|
||||
def _plot( # type: ignore[override]
|
||||
cls,
|
||||
ax: Axes,
|
||||
y: np.ndarray,
|
||||
style=None,
|
||||
bw_method=None,
|
||||
ind=None,
|
||||
column_num=None,
|
||||
stacking_id: int | None = None,
|
||||
**kwds,
|
||||
):
|
||||
from scipy.stats import gaussian_kde
|
||||
|
||||
y = remove_na_arraylike(y)
|
||||
gkde = gaussian_kde(y, bw_method=bw_method)
|
||||
|
||||
y = gkde.evaluate(ind)
|
||||
lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
|
||||
return lines
|
||||
|
||||
def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
|
||||
kwds["bw_method"] = self.bw_method
|
||||
kwds["ind"] = type(self)._get_ind(y, ind=self.ind)
|
||||
|
||||
def _post_plot_logic(self, ax: Axes, data) -> None:
|
||||
ax.set_ylabel("Density")
|
||||
|
||||
|
||||
def _grouped_plot(
|
||||
plotf,
|
||||
data: Series | DataFrame,
|
||||
column=None,
|
||||
by=None,
|
||||
numeric_only: bool = True,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
sharex: bool = True,
|
||||
sharey: bool = True,
|
||||
layout=None,
|
||||
rot: float = 0,
|
||||
ax=None,
|
||||
**kwargs,
|
||||
):
|
||||
# error: Non-overlapping equality check (left operand type: "Optional[Tuple[float,
|
||||
# float]]", right operand type: "Literal['default']")
|
||||
if figsize == "default": # type: ignore[comparison-overlap]
|
||||
# allowed to specify mpl default with 'default'
|
||||
raise ValueError(
|
||||
"figsize='default' is no longer supported. "
|
||||
"Specify figure size by tuple instead"
|
||||
)
|
||||
|
||||
grouped = data.groupby(by)
|
||||
if column is not None:
|
||||
grouped = grouped[column]
|
||||
|
||||
naxes = len(grouped)
|
||||
fig, axes = create_subplots(
|
||||
naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
|
||||
)
|
||||
|
||||
_axes = flatten_axes(axes)
|
||||
|
||||
for i, (key, group) in enumerate(grouped):
|
||||
ax = _axes[i]
|
||||
if numeric_only and isinstance(group, ABCDataFrame):
|
||||
group = group._get_numeric_data()
|
||||
plotf(group, ax, **kwargs)
|
||||
ax.set_title(pprint_thing(key))
|
||||
|
||||
return fig, axes
|
||||
|
||||
|
||||
def _grouped_hist(
|
||||
data: Series | DataFrame,
|
||||
column=None,
|
||||
by=None,
|
||||
ax=None,
|
||||
bins: int = 50,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
layout=None,
|
||||
sharex: bool = False,
|
||||
sharey: bool = False,
|
||||
rot: float = 90,
|
||||
grid: bool = True,
|
||||
xlabelsize: int | None = None,
|
||||
xrot=None,
|
||||
ylabelsize: int | None = None,
|
||||
yrot=None,
|
||||
legend: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Grouped histogram
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : Series/DataFrame
|
||||
column : object, optional
|
||||
by : object, optional
|
||||
ax : axes, optional
|
||||
bins : int, default 50
|
||||
figsize : tuple, optional
|
||||
layout : optional
|
||||
sharex : bool, default False
|
||||
sharey : bool, default False
|
||||
rot : float, default 90
|
||||
grid : bool, default True
|
||||
legend: : bool, default False
|
||||
kwargs : dict, keyword arguments passed to matplotlib.Axes.hist
|
||||
|
||||
Returns
|
||||
-------
|
||||
collection of Matplotlib Axes
|
||||
"""
|
||||
if legend:
|
||||
assert "label" not in kwargs
|
||||
if data.ndim == 1:
|
||||
kwargs["label"] = data.name
|
||||
elif column is None:
|
||||
kwargs["label"] = data.columns
|
||||
else:
|
||||
kwargs["label"] = column
|
||||
|
||||
def plot_group(group, ax) -> None:
|
||||
ax.hist(group.dropna().values, bins=bins, **kwargs)
|
||||
if legend:
|
||||
ax.legend()
|
||||
|
||||
if xrot is None:
|
||||
xrot = rot
|
||||
|
||||
fig, axes = _grouped_plot(
|
||||
plot_group,
|
||||
data,
|
||||
column=column,
|
||||
by=by,
|
||||
sharex=sharex,
|
||||
sharey=sharey,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
rot=rot,
|
||||
)
|
||||
|
||||
set_ticks_props(
|
||||
axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
|
||||
)
|
||||
|
||||
maybe_adjust_figure(
|
||||
fig, bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3
|
||||
)
|
||||
return axes
|
||||
|
||||
|
||||
def hist_series(
|
||||
self: Series,
|
||||
by=None,
|
||||
ax=None,
|
||||
grid: bool = True,
|
||||
xlabelsize: int | None = None,
|
||||
xrot=None,
|
||||
ylabelsize: int | None = None,
|
||||
yrot=None,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
bins: int = 10,
|
||||
legend: bool = False,
|
||||
**kwds,
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if legend and "label" in kwds:
|
||||
raise ValueError("Cannot use both legend and label")
|
||||
|
||||
if by is None:
|
||||
if kwds.get("layout", None) is not None:
|
||||
raise ValueError("The 'layout' keyword is not supported when 'by' is None")
|
||||
# hack until the plotting interface is a bit more unified
|
||||
fig = kwds.pop(
|
||||
"figure", plt.gcf() if plt.get_fignums() else plt.figure(figsize=figsize)
|
||||
)
|
||||
if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
|
||||
fig.set_size_inches(*figsize, forward=True)
|
||||
if ax is None:
|
||||
ax = fig.gca()
|
||||
elif ax.get_figure() != fig:
|
||||
raise AssertionError("passed axis not bound to passed figure")
|
||||
values = self.dropna().values
|
||||
if legend:
|
||||
kwds["label"] = self.name
|
||||
ax.hist(values, bins=bins, **kwds)
|
||||
if legend:
|
||||
ax.legend()
|
||||
ax.grid(grid)
|
||||
axes = np.array([ax])
|
||||
|
||||
# error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any,
|
||||
# dtype[Any]]"; expected "Axes | Sequence[Axes]"
|
||||
set_ticks_props(
|
||||
axes, # type: ignore[arg-type]
|
||||
xlabelsize=xlabelsize,
|
||||
xrot=xrot,
|
||||
ylabelsize=ylabelsize,
|
||||
yrot=yrot,
|
||||
)
|
||||
|
||||
else:
|
||||
if "figure" in kwds:
|
||||
raise ValueError(
|
||||
"Cannot pass 'figure' when using the "
|
||||
"'by' argument, since a new 'Figure' instance will be created"
|
||||
)
|
||||
axes = _grouped_hist(
|
||||
self,
|
||||
by=by,
|
||||
ax=ax,
|
||||
grid=grid,
|
||||
figsize=figsize,
|
||||
bins=bins,
|
||||
xlabelsize=xlabelsize,
|
||||
xrot=xrot,
|
||||
ylabelsize=ylabelsize,
|
||||
yrot=yrot,
|
||||
legend=legend,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
if hasattr(axes, "ndim"):
|
||||
if axes.ndim == 1 and len(axes) == 1:
|
||||
return axes[0]
|
||||
return axes
|
||||
|
||||
|
||||
def hist_frame(
|
||||
data: DataFrame,
|
||||
column=None,
|
||||
by=None,
|
||||
grid: bool = True,
|
||||
xlabelsize: int | None = None,
|
||||
xrot=None,
|
||||
ylabelsize: int | None = None,
|
||||
yrot=None,
|
||||
ax=None,
|
||||
sharex: bool = False,
|
||||
sharey: bool = False,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
layout=None,
|
||||
bins: int = 10,
|
||||
legend: bool = False,
|
||||
**kwds,
|
||||
):
|
||||
if legend and "label" in kwds:
|
||||
raise ValueError("Cannot use both legend and label")
|
||||
if by is not None:
|
||||
axes = _grouped_hist(
|
||||
data,
|
||||
column=column,
|
||||
by=by,
|
||||
ax=ax,
|
||||
grid=grid,
|
||||
figsize=figsize,
|
||||
sharex=sharex,
|
||||
sharey=sharey,
|
||||
layout=layout,
|
||||
bins=bins,
|
||||
xlabelsize=xlabelsize,
|
||||
xrot=xrot,
|
||||
ylabelsize=ylabelsize,
|
||||
yrot=yrot,
|
||||
legend=legend,
|
||||
**kwds,
|
||||
)
|
||||
return axes
|
||||
|
||||
if column is not None:
|
||||
if not isinstance(column, (list, np.ndarray, ABCIndex)):
|
||||
column = [column]
|
||||
data = data[column]
|
||||
# GH32590
|
||||
data = data.select_dtypes(
|
||||
include=(np.number, "datetime64", "datetimetz"), exclude="timedelta"
|
||||
)
|
||||
naxes = len(data.columns)
|
||||
|
||||
if naxes == 0:
|
||||
raise ValueError(
|
||||
"hist method requires numerical or datetime columns, nothing to plot."
|
||||
)
|
||||
|
||||
fig, axes = create_subplots(
|
||||
naxes=naxes,
|
||||
ax=ax,
|
||||
squeeze=False,
|
||||
sharex=sharex,
|
||||
sharey=sharey,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
)
|
||||
_axes = flatten_axes(axes)
|
||||
|
||||
can_set_label = "label" not in kwds
|
||||
|
||||
for i, col in enumerate(data.columns):
|
||||
ax = _axes[i]
|
||||
if legend and can_set_label:
|
||||
kwds["label"] = col
|
||||
ax.hist(data[col].dropna().values, bins=bins, **kwds)
|
||||
ax.set_title(col)
|
||||
ax.grid(grid)
|
||||
if legend:
|
||||
ax.legend()
|
||||
|
||||
set_ticks_props(
|
||||
axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
|
||||
)
|
||||
maybe_adjust_figure(fig, wspace=0.3, hspace=0.3)
|
||||
|
||||
return axes
|
481
lib/python3.13/site-packages/pandas/plotting/_matplotlib/misc.py
Normal file
481
lib/python3.13/site-packages/pandas/plotting/_matplotlib/misc.py
Normal file
@ -0,0 +1,481 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from matplotlib import patches
|
||||
import matplotlib.lines as mlines
|
||||
import numpy as np
|
||||
|
||||
from pandas.core.dtypes.missing import notna
|
||||
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.plotting._matplotlib.style import get_standard_colors
|
||||
from pandas.plotting._matplotlib.tools import (
|
||||
create_subplots,
|
||||
do_adjust_figure,
|
||||
maybe_adjust_figure,
|
||||
set_ticks_props,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Hashable
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
Index,
|
||||
Series,
|
||||
)
|
||||
|
||||
|
||||
def scatter_matrix(
|
||||
frame: DataFrame,
|
||||
alpha: float = 0.5,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
ax=None,
|
||||
grid: bool = False,
|
||||
diagonal: str = "hist",
|
||||
marker: str = ".",
|
||||
density_kwds=None,
|
||||
hist_kwds=None,
|
||||
range_padding: float = 0.05,
|
||||
**kwds,
|
||||
):
|
||||
df = frame._get_numeric_data()
|
||||
n = df.columns.size
|
||||
naxes = n * n
|
||||
fig, axes = create_subplots(naxes=naxes, figsize=figsize, ax=ax, squeeze=False)
|
||||
|
||||
# no gaps between subplots
|
||||
maybe_adjust_figure(fig, wspace=0, hspace=0)
|
||||
|
||||
mask = notna(df)
|
||||
|
||||
marker = _get_marker_compat(marker)
|
||||
|
||||
hist_kwds = hist_kwds or {}
|
||||
density_kwds = density_kwds or {}
|
||||
|
||||
# GH 14855
|
||||
kwds.setdefault("edgecolors", "none")
|
||||
|
||||
boundaries_list = []
|
||||
for a in df.columns:
|
||||
values = df[a].values[mask[a].values]
|
||||
rmin_, rmax_ = np.min(values), np.max(values)
|
||||
rdelta_ext = (rmax_ - rmin_) * range_padding / 2
|
||||
boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext))
|
||||
|
||||
for i, a in enumerate(df.columns):
|
||||
for j, b in enumerate(df.columns):
|
||||
ax = axes[i, j]
|
||||
|
||||
if i == j:
|
||||
values = df[a].values[mask[a].values]
|
||||
|
||||
# Deal with the diagonal by drawing a histogram there.
|
||||
if diagonal == "hist":
|
||||
ax.hist(values, **hist_kwds)
|
||||
|
||||
elif diagonal in ("kde", "density"):
|
||||
from scipy.stats import gaussian_kde
|
||||
|
||||
y = values
|
||||
gkde = gaussian_kde(y)
|
||||
ind = np.linspace(y.min(), y.max(), 1000)
|
||||
ax.plot(ind, gkde.evaluate(ind), **density_kwds)
|
||||
|
||||
ax.set_xlim(boundaries_list[i])
|
||||
|
||||
else:
|
||||
common = (mask[a] & mask[b]).values
|
||||
|
||||
ax.scatter(
|
||||
df[b][common], df[a][common], marker=marker, alpha=alpha, **kwds
|
||||
)
|
||||
|
||||
ax.set_xlim(boundaries_list[j])
|
||||
ax.set_ylim(boundaries_list[i])
|
||||
|
||||
ax.set_xlabel(b)
|
||||
ax.set_ylabel(a)
|
||||
|
||||
if j != 0:
|
||||
ax.yaxis.set_visible(False)
|
||||
if i != n - 1:
|
||||
ax.xaxis.set_visible(False)
|
||||
|
||||
if len(df.columns) > 1:
|
||||
lim1 = boundaries_list[0]
|
||||
locs = axes[0][1].yaxis.get_majorticklocs()
|
||||
locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
|
||||
adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
|
||||
|
||||
lim0 = axes[0][0].get_ylim()
|
||||
adj = adj * (lim0[1] - lim0[0]) + lim0[0]
|
||||
axes[0][0].yaxis.set_ticks(adj)
|
||||
|
||||
if np.all(locs == locs.astype(int)):
|
||||
# if all ticks are int
|
||||
locs = locs.astype(int)
|
||||
axes[0][0].yaxis.set_ticklabels(locs)
|
||||
|
||||
set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
|
||||
|
||||
return axes
|
||||
|
||||
|
||||
def _get_marker_compat(marker):
|
||||
if marker not in mlines.lineMarkers:
|
||||
return "o"
|
||||
return marker
|
||||
|
||||
|
||||
def radviz(
|
||||
frame: DataFrame,
|
||||
class_column,
|
||||
ax: Axes | None = None,
|
||||
color=None,
|
||||
colormap=None,
|
||||
**kwds,
|
||||
) -> Axes:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def normalize(series):
|
||||
a = min(series)
|
||||
b = max(series)
|
||||
return (series - a) / (b - a)
|
||||
|
||||
n = len(frame)
|
||||
classes = frame[class_column].drop_duplicates()
|
||||
class_col = frame[class_column]
|
||||
df = frame.drop(class_column, axis=1).apply(normalize)
|
||||
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
ax.set_xlim(-1, 1)
|
||||
ax.set_ylim(-1, 1)
|
||||
|
||||
to_plot: dict[Hashable, list[list]] = {}
|
||||
colors = get_standard_colors(
|
||||
num_colors=len(classes), colormap=colormap, color_type="random", color=color
|
||||
)
|
||||
|
||||
for kls in classes:
|
||||
to_plot[kls] = [[], []]
|
||||
|
||||
m = len(frame.columns) - 1
|
||||
s = np.array(
|
||||
[(np.cos(t), np.sin(t)) for t in [2 * np.pi * (i / m) for i in range(m)]]
|
||||
)
|
||||
|
||||
for i in range(n):
|
||||
row = df.iloc[i].values
|
||||
row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
|
||||
y = (s * row_).sum(axis=0) / row.sum()
|
||||
kls = class_col.iat[i]
|
||||
to_plot[kls][0].append(y[0])
|
||||
to_plot[kls][1].append(y[1])
|
||||
|
||||
for i, kls in enumerate(classes):
|
||||
ax.scatter(
|
||||
to_plot[kls][0],
|
||||
to_plot[kls][1],
|
||||
color=colors[i],
|
||||
label=pprint_thing(kls),
|
||||
**kwds,
|
||||
)
|
||||
ax.legend()
|
||||
|
||||
ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))
|
||||
|
||||
for xy, name in zip(s, df.columns):
|
||||
ax.add_patch(patches.Circle(xy, radius=0.025, facecolor="gray"))
|
||||
|
||||
if xy[0] < 0.0 and xy[1] < 0.0:
|
||||
ax.text(
|
||||
xy[0] - 0.025, xy[1] - 0.025, name, ha="right", va="top", size="small"
|
||||
)
|
||||
elif xy[0] < 0.0 <= xy[1]:
|
||||
ax.text(
|
||||
xy[0] - 0.025,
|
||||
xy[1] + 0.025,
|
||||
name,
|
||||
ha="right",
|
||||
va="bottom",
|
||||
size="small",
|
||||
)
|
||||
elif xy[1] < 0.0 <= xy[0]:
|
||||
ax.text(
|
||||
xy[0] + 0.025, xy[1] - 0.025, name, ha="left", va="top", size="small"
|
||||
)
|
||||
elif xy[0] >= 0.0 and xy[1] >= 0.0:
|
||||
ax.text(
|
||||
xy[0] + 0.025, xy[1] + 0.025, name, ha="left", va="bottom", size="small"
|
||||
)
|
||||
|
||||
ax.axis("equal")
|
||||
return ax
|
||||
|
||||
|
||||
def andrews_curves(
|
||||
frame: DataFrame,
|
||||
class_column,
|
||||
ax: Axes | None = None,
|
||||
samples: int = 200,
|
||||
color=None,
|
||||
colormap=None,
|
||||
**kwds,
|
||||
) -> Axes:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def function(amplitudes):
|
||||
def f(t):
|
||||
x1 = amplitudes[0]
|
||||
result = x1 / np.sqrt(2.0)
|
||||
|
||||
# Take the rest of the coefficients and resize them
|
||||
# appropriately. Take a copy of amplitudes as otherwise numpy
|
||||
# deletes the element from amplitudes itself.
|
||||
coeffs = np.delete(np.copy(amplitudes), 0)
|
||||
coeffs = np.resize(coeffs, (int((coeffs.size + 1) / 2), 2))
|
||||
|
||||
# Generate the harmonics and arguments for the sin and cos
|
||||
# functions.
|
||||
harmonics = np.arange(0, coeffs.shape[0]) + 1
|
||||
trig_args = np.outer(harmonics, t)
|
||||
|
||||
result += np.sum(
|
||||
coeffs[:, 0, np.newaxis] * np.sin(trig_args)
|
||||
+ coeffs[:, 1, np.newaxis] * np.cos(trig_args),
|
||||
axis=0,
|
||||
)
|
||||
return result
|
||||
|
||||
return f
|
||||
|
||||
n = len(frame)
|
||||
class_col = frame[class_column]
|
||||
classes = frame[class_column].drop_duplicates()
|
||||
df = frame.drop(class_column, axis=1)
|
||||
t = np.linspace(-np.pi, np.pi, samples)
|
||||
used_legends: set[str] = set()
|
||||
|
||||
color_values = get_standard_colors(
|
||||
num_colors=len(classes), colormap=colormap, color_type="random", color=color
|
||||
)
|
||||
colors = dict(zip(classes, color_values))
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
ax.set_xlim(-np.pi, np.pi)
|
||||
for i in range(n):
|
||||
row = df.iloc[i].values
|
||||
f = function(row)
|
||||
y = f(t)
|
||||
kls = class_col.iat[i]
|
||||
label = pprint_thing(kls)
|
||||
if label not in used_legends:
|
||||
used_legends.add(label)
|
||||
ax.plot(t, y, color=colors[kls], label=label, **kwds)
|
||||
else:
|
||||
ax.plot(t, y, color=colors[kls], **kwds)
|
||||
|
||||
ax.legend(loc="upper right")
|
||||
ax.grid()
|
||||
return ax
|
||||
|
||||
|
||||
def bootstrap_plot(
|
||||
series: Series,
|
||||
fig: Figure | None = None,
|
||||
size: int = 50,
|
||||
samples: int = 500,
|
||||
**kwds,
|
||||
) -> Figure:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# TODO: is the failure mentioned below still relevant?
|
||||
# random.sample(ndarray, int) fails on python 3.3, sigh
|
||||
data = list(series.values)
|
||||
samplings = [random.sample(data, size) for _ in range(samples)]
|
||||
|
||||
means = np.array([np.mean(sampling) for sampling in samplings])
|
||||
medians = np.array([np.median(sampling) for sampling in samplings])
|
||||
midranges = np.array(
|
||||
[(min(sampling) + max(sampling)) * 0.5 for sampling in samplings]
|
||||
)
|
||||
if fig is None:
|
||||
fig = plt.figure()
|
||||
x = list(range(samples))
|
||||
axes = []
|
||||
ax1 = fig.add_subplot(2, 3, 1)
|
||||
ax1.set_xlabel("Sample")
|
||||
axes.append(ax1)
|
||||
ax1.plot(x, means, **kwds)
|
||||
ax2 = fig.add_subplot(2, 3, 2)
|
||||
ax2.set_xlabel("Sample")
|
||||
axes.append(ax2)
|
||||
ax2.plot(x, medians, **kwds)
|
||||
ax3 = fig.add_subplot(2, 3, 3)
|
||||
ax3.set_xlabel("Sample")
|
||||
axes.append(ax3)
|
||||
ax3.plot(x, midranges, **kwds)
|
||||
ax4 = fig.add_subplot(2, 3, 4)
|
||||
ax4.set_xlabel("Mean")
|
||||
axes.append(ax4)
|
||||
ax4.hist(means, **kwds)
|
||||
ax5 = fig.add_subplot(2, 3, 5)
|
||||
ax5.set_xlabel("Median")
|
||||
axes.append(ax5)
|
||||
ax5.hist(medians, **kwds)
|
||||
ax6 = fig.add_subplot(2, 3, 6)
|
||||
ax6.set_xlabel("Midrange")
|
||||
axes.append(ax6)
|
||||
ax6.hist(midranges, **kwds)
|
||||
for axis in axes:
|
||||
plt.setp(axis.get_xticklabels(), fontsize=8)
|
||||
plt.setp(axis.get_yticklabels(), fontsize=8)
|
||||
if do_adjust_figure(fig):
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
|
||||
def parallel_coordinates(
|
||||
frame: DataFrame,
|
||||
class_column,
|
||||
cols=None,
|
||||
ax: Axes | None = None,
|
||||
color=None,
|
||||
use_columns: bool = False,
|
||||
xticks=None,
|
||||
colormap=None,
|
||||
axvlines: bool = True,
|
||||
axvlines_kwds=None,
|
||||
sort_labels: bool = False,
|
||||
**kwds,
|
||||
) -> Axes:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if axvlines_kwds is None:
|
||||
axvlines_kwds = {"linewidth": 1, "color": "black"}
|
||||
|
||||
n = len(frame)
|
||||
classes = frame[class_column].drop_duplicates()
|
||||
class_col = frame[class_column]
|
||||
|
||||
if cols is None:
|
||||
df = frame.drop(class_column, axis=1)
|
||||
else:
|
||||
df = frame[cols]
|
||||
|
||||
used_legends: set[str] = set()
|
||||
|
||||
ncols = len(df.columns)
|
||||
|
||||
# determine values to use for xticks
|
||||
x: list[int] | Index
|
||||
if use_columns is True:
|
||||
if not np.all(np.isreal(list(df.columns))):
|
||||
raise ValueError("Columns must be numeric to be used as xticks")
|
||||
x = df.columns
|
||||
elif xticks is not None:
|
||||
if not np.all(np.isreal(xticks)):
|
||||
raise ValueError("xticks specified must be numeric")
|
||||
if len(xticks) != ncols:
|
||||
raise ValueError("Length of xticks must match number of columns")
|
||||
x = xticks
|
||||
else:
|
||||
x = list(range(ncols))
|
||||
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
|
||||
color_values = get_standard_colors(
|
||||
num_colors=len(classes), colormap=colormap, color_type="random", color=color
|
||||
)
|
||||
|
||||
if sort_labels:
|
||||
classes = sorted(classes)
|
||||
color_values = sorted(color_values)
|
||||
colors = dict(zip(classes, color_values))
|
||||
|
||||
for i in range(n):
|
||||
y = df.iloc[i].values
|
||||
kls = class_col.iat[i]
|
||||
label = pprint_thing(kls)
|
||||
if label not in used_legends:
|
||||
used_legends.add(label)
|
||||
ax.plot(x, y, color=colors[kls], label=label, **kwds)
|
||||
else:
|
||||
ax.plot(x, y, color=colors[kls], **kwds)
|
||||
|
||||
if axvlines:
|
||||
for i in x:
|
||||
ax.axvline(i, **axvlines_kwds)
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(df.columns)
|
||||
ax.set_xlim(x[0], x[-1])
|
||||
ax.legend(loc="upper right")
|
||||
ax.grid()
|
||||
return ax
|
||||
|
||||
|
||||
def lag_plot(series: Series, lag: int = 1, ax: Axes | None = None, **kwds) -> Axes:
|
||||
# workaround because `c='b'` is hardcoded in matplotlib's scatter method
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
kwds.setdefault("c", plt.rcParams["patch.facecolor"])
|
||||
|
||||
data = series.values
|
||||
y1 = data[:-lag]
|
||||
y2 = data[lag:]
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
ax.set_xlabel("y(t)")
|
||||
ax.set_ylabel(f"y(t + {lag})")
|
||||
ax.scatter(y1, y2, **kwds)
|
||||
return ax
|
||||
|
||||
|
||||
def autocorrelation_plot(series: Series, ax: Axes | None = None, **kwds) -> Axes:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
n = len(series)
|
||||
data = np.asarray(series)
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
ax.set_xlim(1, n)
|
||||
ax.set_ylim(-1.0, 1.0)
|
||||
mean = np.mean(data)
|
||||
c0 = np.sum((data - mean) ** 2) / n
|
||||
|
||||
def r(h):
|
||||
return ((data[: n - h] - mean) * (data[h:] - mean)).sum() / n / c0
|
||||
|
||||
x = np.arange(n) + 1
|
||||
y = [r(loc) for loc in x]
|
||||
z95 = 1.959963984540054
|
||||
z99 = 2.5758293035489004
|
||||
ax.axhline(y=z99 / np.sqrt(n), linestyle="--", color="grey")
|
||||
ax.axhline(y=z95 / np.sqrt(n), color="grey")
|
||||
ax.axhline(y=0.0, color="black")
|
||||
ax.axhline(y=-z95 / np.sqrt(n), color="grey")
|
||||
ax.axhline(y=-z99 / np.sqrt(n), linestyle="--", color="grey")
|
||||
ax.set_xlabel("Lag")
|
||||
ax.set_ylabel("Autocorrelation")
|
||||
ax.plot(x, y, **kwds)
|
||||
if "label" in kwds:
|
||||
ax.legend()
|
||||
ax.grid()
|
||||
return ax
|
||||
|
||||
|
||||
def unpack_single_str_list(keys):
|
||||
# GH 42795
|
||||
if isinstance(keys, list) and len(keys) == 1:
|
||||
keys = keys[0]
|
||||
return keys
|
@ -0,0 +1,278 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import (
|
||||
Collection,
|
||||
Iterator,
|
||||
)
|
||||
import itertools
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
cast,
|
||||
)
|
||||
import warnings
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.colors
|
||||
import numpy as np
|
||||
|
||||
from pandas._typing import MatplotlibColor as Color
|
||||
from pandas.util._exceptions import find_stack_level
|
||||
|
||||
from pandas.core.dtypes.common import is_list_like
|
||||
|
||||
import pandas.core.common as com
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.colors import Colormap
|
||||
|
||||
|
||||
def get_standard_colors(
|
||||
num_colors: int,
|
||||
colormap: Colormap | None = None,
|
||||
color_type: str = "default",
|
||||
color: dict[str, Color] | Color | Collection[Color] | None = None,
|
||||
):
|
||||
"""
|
||||
Get standard colors based on `colormap`, `color_type` or `color` inputs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_colors : int
|
||||
Minimum number of colors to be returned.
|
||||
Ignored if `color` is a dictionary.
|
||||
colormap : :py:class:`matplotlib.colors.Colormap`, optional
|
||||
Matplotlib colormap.
|
||||
When provided, the resulting colors will be derived from the colormap.
|
||||
color_type : {"default", "random"}, optional
|
||||
Type of colors to derive. Used if provided `color` and `colormap` are None.
|
||||
Ignored if either `color` or `colormap` are not None.
|
||||
color : dict or str or sequence, optional
|
||||
Color(s) to be used for deriving sequence of colors.
|
||||
Can be either be a dictionary, or a single color (single color string,
|
||||
or sequence of floats representing a single color),
|
||||
or a sequence of colors.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict or list
|
||||
Standard colors. Can either be a mapping if `color` was a dictionary,
|
||||
or a list of colors with a length of `num_colors` or more.
|
||||
|
||||
Warns
|
||||
-----
|
||||
UserWarning
|
||||
If both `colormap` and `color` are provided.
|
||||
Parameter `color` will override.
|
||||
"""
|
||||
if isinstance(color, dict):
|
||||
return color
|
||||
|
||||
colors = _derive_colors(
|
||||
color=color,
|
||||
colormap=colormap,
|
||||
color_type=color_type,
|
||||
num_colors=num_colors,
|
||||
)
|
||||
|
||||
return list(_cycle_colors(colors, num_colors=num_colors))
|
||||
|
||||
|
||||
def _derive_colors(
|
||||
*,
|
||||
color: Color | Collection[Color] | None,
|
||||
colormap: str | Colormap | None,
|
||||
color_type: str,
|
||||
num_colors: int,
|
||||
) -> list[Color]:
|
||||
"""
|
||||
Derive colors from either `colormap`, `color_type` or `color` inputs.
|
||||
|
||||
Get a list of colors either from `colormap`, or from `color`,
|
||||
or from `color_type` (if both `colormap` and `color` are None).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
color : str or sequence, optional
|
||||
Color(s) to be used for deriving sequence of colors.
|
||||
Can be either be a single color (single color string, or sequence of floats
|
||||
representing a single color), or a sequence of colors.
|
||||
colormap : :py:class:`matplotlib.colors.Colormap`, optional
|
||||
Matplotlib colormap.
|
||||
When provided, the resulting colors will be derived from the colormap.
|
||||
color_type : {"default", "random"}, optional
|
||||
Type of colors to derive. Used if provided `color` and `colormap` are None.
|
||||
Ignored if either `color` or `colormap`` are not None.
|
||||
num_colors : int
|
||||
Number of colors to be extracted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
List of colors extracted.
|
||||
|
||||
Warns
|
||||
-----
|
||||
UserWarning
|
||||
If both `colormap` and `color` are provided.
|
||||
Parameter `color` will override.
|
||||
"""
|
||||
if color is None and colormap is not None:
|
||||
return _get_colors_from_colormap(colormap, num_colors=num_colors)
|
||||
elif color is not None:
|
||||
if colormap is not None:
|
||||
warnings.warn(
|
||||
"'color' and 'colormap' cannot be used simultaneously. Using 'color'",
|
||||
stacklevel=find_stack_level(),
|
||||
)
|
||||
return _get_colors_from_color(color)
|
||||
else:
|
||||
return _get_colors_from_color_type(color_type, num_colors=num_colors)
|
||||
|
||||
|
||||
def _cycle_colors(colors: list[Color], num_colors: int) -> Iterator[Color]:
|
||||
"""Cycle colors until achieving max of `num_colors` or length of `colors`.
|
||||
|
||||
Extra colors will be ignored by matplotlib if there are more colors
|
||||
than needed and nothing needs to be done here.
|
||||
"""
|
||||
max_colors = max(num_colors, len(colors))
|
||||
yield from itertools.islice(itertools.cycle(colors), max_colors)
|
||||
|
||||
|
||||
def _get_colors_from_colormap(
|
||||
colormap: str | Colormap,
|
||||
num_colors: int,
|
||||
) -> list[Color]:
|
||||
"""Get colors from colormap."""
|
||||
cmap = _get_cmap_instance(colormap)
|
||||
return [cmap(num) for num in np.linspace(0, 1, num=num_colors)]
|
||||
|
||||
|
||||
def _get_cmap_instance(colormap: str | Colormap) -> Colormap:
|
||||
"""Get instance of matplotlib colormap."""
|
||||
if isinstance(colormap, str):
|
||||
cmap = colormap
|
||||
colormap = mpl.colormaps[colormap]
|
||||
if colormap is None:
|
||||
raise ValueError(f"Colormap {cmap} is not recognized")
|
||||
return colormap
|
||||
|
||||
|
||||
def _get_colors_from_color(
|
||||
color: Color | Collection[Color],
|
||||
) -> list[Color]:
|
||||
"""Get colors from user input color."""
|
||||
if len(color) == 0:
|
||||
raise ValueError(f"Invalid color argument: {color}")
|
||||
|
||||
if _is_single_color(color):
|
||||
color = cast(Color, color)
|
||||
return [color]
|
||||
|
||||
color = cast(Collection[Color], color)
|
||||
return list(_gen_list_of_colors_from_iterable(color))
|
||||
|
||||
|
||||
def _is_single_color(color: Color | Collection[Color]) -> bool:
|
||||
"""Check if `color` is a single color, not a sequence of colors.
|
||||
|
||||
Single color is of these kinds:
|
||||
- Named color "red", "C0", "firebrick"
|
||||
- Alias "g"
|
||||
- Sequence of floats, such as (0.1, 0.2, 0.3) or (0.1, 0.2, 0.3, 0.4).
|
||||
|
||||
See Also
|
||||
--------
|
||||
_is_single_string_color
|
||||
"""
|
||||
if isinstance(color, str) and _is_single_string_color(color):
|
||||
# GH #36972
|
||||
return True
|
||||
|
||||
if _is_floats_color(color):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _gen_list_of_colors_from_iterable(color: Collection[Color]) -> Iterator[Color]:
|
||||
"""
|
||||
Yield colors from string of several letters or from collection of colors.
|
||||
"""
|
||||
for x in color:
|
||||
if _is_single_color(x):
|
||||
yield x
|
||||
else:
|
||||
raise ValueError(f"Invalid color {x}")
|
||||
|
||||
|
||||
def _is_floats_color(color: Color | Collection[Color]) -> bool:
|
||||
"""Check if color comprises a sequence of floats representing color."""
|
||||
return bool(
|
||||
is_list_like(color)
|
||||
and (len(color) == 3 or len(color) == 4)
|
||||
and all(isinstance(x, (int, float)) for x in color)
|
||||
)
|
||||
|
||||
|
||||
def _get_colors_from_color_type(color_type: str, num_colors: int) -> list[Color]:
|
||||
"""Get colors from user input color type."""
|
||||
if color_type == "default":
|
||||
return _get_default_colors(num_colors)
|
||||
elif color_type == "random":
|
||||
return _get_random_colors(num_colors)
|
||||
else:
|
||||
raise ValueError("color_type must be either 'default' or 'random'")
|
||||
|
||||
|
||||
def _get_default_colors(num_colors: int) -> list[Color]:
|
||||
"""Get `num_colors` of default colors from matplotlib rc params."""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]]
|
||||
return colors[0:num_colors]
|
||||
|
||||
|
||||
def _get_random_colors(num_colors: int) -> list[Color]:
|
||||
"""Get `num_colors` of random colors."""
|
||||
return [_random_color(num) for num in range(num_colors)]
|
||||
|
||||
|
||||
def _random_color(column: int) -> list[float]:
|
||||
"""Get a random color represented as a list of length 3"""
|
||||
# GH17525 use common._random_state to avoid resetting the seed
|
||||
rs = com.random_state(column)
|
||||
return rs.rand(3).tolist()
|
||||
|
||||
|
||||
def _is_single_string_color(color: Color) -> bool:
|
||||
"""Check if `color` is a single string color.
|
||||
|
||||
Examples of single string colors:
|
||||
- 'r'
|
||||
- 'g'
|
||||
- 'red'
|
||||
- 'green'
|
||||
- 'C3'
|
||||
- 'firebrick'
|
||||
|
||||
Parameters
|
||||
----------
|
||||
color : Color
|
||||
Color string or sequence of floats.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if `color` looks like a valid color.
|
||||
False otherwise.
|
||||
"""
|
||||
conv = matplotlib.colors.ColorConverter()
|
||||
try:
|
||||
# error: Argument 1 to "to_rgba" of "ColorConverter" has incompatible type
|
||||
# "str | Sequence[float]"; expected "tuple[float, float, float] | ..."
|
||||
conv.to_rgba(color) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
return False
|
||||
else:
|
||||
return True
|
@ -0,0 +1,370 @@
|
||||
# TODO: Use the fact that axis can have units to simplify the process
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
cast,
|
||||
)
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pandas._libs.tslibs import (
|
||||
BaseOffset,
|
||||
Period,
|
||||
to_offset,
|
||||
)
|
||||
from pandas._libs.tslibs.dtypes import (
|
||||
OFFSET_TO_PERIOD_FREQSTR,
|
||||
FreqGroup,
|
||||
)
|
||||
|
||||
from pandas.core.dtypes.generic import (
|
||||
ABCDatetimeIndex,
|
||||
ABCPeriodIndex,
|
||||
ABCTimedeltaIndex,
|
||||
)
|
||||
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.plotting._matplotlib.converter import (
|
||||
TimeSeries_DateFormatter,
|
||||
TimeSeries_DateLocator,
|
||||
TimeSeries_TimedeltaFormatter,
|
||||
)
|
||||
from pandas.tseries.frequencies import (
|
||||
get_period_alias,
|
||||
is_subperiod,
|
||||
is_superperiod,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
from pandas._typing import NDFrameT
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
DatetimeIndex,
|
||||
Index,
|
||||
PeriodIndex,
|
||||
Series,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Plotting functions and monkey patches
|
||||
|
||||
|
||||
def maybe_resample(series: Series, ax: Axes, kwargs: dict[str, Any]):
|
||||
# resample against axes freq if necessary
|
||||
|
||||
if "how" in kwargs:
|
||||
raise ValueError(
|
||||
"'how' is not a valid keyword for plotting functions. If plotting "
|
||||
"multiple objects on shared axes, resample manually first."
|
||||
)
|
||||
|
||||
freq, ax_freq = _get_freq(ax, series)
|
||||
|
||||
if freq is None: # pragma: no cover
|
||||
raise ValueError("Cannot use dynamic axis without frequency info")
|
||||
|
||||
# Convert DatetimeIndex to PeriodIndex
|
||||
if isinstance(series.index, ABCDatetimeIndex):
|
||||
series = series.to_period(freq=freq)
|
||||
|
||||
if ax_freq is not None and freq != ax_freq:
|
||||
if is_superperiod(freq, ax_freq): # upsample input
|
||||
series = series.copy()
|
||||
# error: "Index" has no attribute "asfreq"
|
||||
series.index = series.index.asfreq( # type: ignore[attr-defined]
|
||||
ax_freq, how="s"
|
||||
)
|
||||
freq = ax_freq
|
||||
elif _is_sup(freq, ax_freq): # one is weekly
|
||||
# Resampling with PeriodDtype is deprecated, so we convert to
|
||||
# DatetimeIndex, resample, then convert back.
|
||||
ser_ts = series.to_timestamp()
|
||||
ser_d = ser_ts.resample("D").last().dropna()
|
||||
ser_freq = ser_d.resample(ax_freq).last().dropna()
|
||||
series = ser_freq.to_period(ax_freq)
|
||||
freq = ax_freq
|
||||
elif is_subperiod(freq, ax_freq) or _is_sub(freq, ax_freq):
|
||||
_upsample_others(ax, freq, kwargs)
|
||||
else: # pragma: no cover
|
||||
raise ValueError("Incompatible frequency conversion")
|
||||
return freq, series
|
||||
|
||||
|
||||
def _is_sub(f1: str, f2: str) -> bool:
|
||||
return (f1.startswith("W") and is_subperiod("D", f2)) or (
|
||||
f2.startswith("W") and is_subperiod(f1, "D")
|
||||
)
|
||||
|
||||
|
||||
def _is_sup(f1: str, f2: str) -> bool:
|
||||
return (f1.startswith("W") and is_superperiod("D", f2)) or (
|
||||
f2.startswith("W") and is_superperiod(f1, "D")
|
||||
)
|
||||
|
||||
|
||||
def _upsample_others(ax: Axes, freq: BaseOffset, kwargs: dict[str, Any]) -> None:
|
||||
legend = ax.get_legend()
|
||||
lines, labels = _replot_ax(ax, freq)
|
||||
_replot_ax(ax, freq)
|
||||
|
||||
other_ax = None
|
||||
if hasattr(ax, "left_ax"):
|
||||
other_ax = ax.left_ax
|
||||
if hasattr(ax, "right_ax"):
|
||||
other_ax = ax.right_ax
|
||||
|
||||
if other_ax is not None:
|
||||
rlines, rlabels = _replot_ax(other_ax, freq)
|
||||
lines.extend(rlines)
|
||||
labels.extend(rlabels)
|
||||
|
||||
if legend is not None and kwargs.get("legend", True) and len(lines) > 0:
|
||||
title: str | None = legend.get_title().get_text()
|
||||
if title == "None":
|
||||
title = None
|
||||
ax.legend(lines, labels, loc="best", title=title)
|
||||
|
||||
|
||||
def _replot_ax(ax: Axes, freq: BaseOffset):
|
||||
data = getattr(ax, "_plot_data", None)
|
||||
|
||||
# clear current axes and data
|
||||
# TODO #54485
|
||||
ax._plot_data = [] # type: ignore[attr-defined]
|
||||
ax.clear()
|
||||
|
||||
decorate_axes(ax, freq)
|
||||
|
||||
lines = []
|
||||
labels = []
|
||||
if data is not None:
|
||||
for series, plotf, kwds in data:
|
||||
series = series.copy()
|
||||
idx = series.index.asfreq(freq, how="S")
|
||||
series.index = idx
|
||||
# TODO #54485
|
||||
ax._plot_data.append((series, plotf, kwds)) # type: ignore[attr-defined]
|
||||
|
||||
# for tsplot
|
||||
if isinstance(plotf, str):
|
||||
from pandas.plotting._matplotlib import PLOT_CLASSES
|
||||
|
||||
plotf = PLOT_CLASSES[plotf]._plot
|
||||
|
||||
lines.append(plotf(ax, series.index._mpl_repr(), series.values, **kwds)[0])
|
||||
labels.append(pprint_thing(series.name))
|
||||
|
||||
return lines, labels
|
||||
|
||||
|
||||
def decorate_axes(ax: Axes, freq: BaseOffset) -> None:
|
||||
"""Initialize axes for time-series plotting"""
|
||||
if not hasattr(ax, "_plot_data"):
|
||||
# TODO #54485
|
||||
ax._plot_data = [] # type: ignore[attr-defined]
|
||||
|
||||
# TODO #54485
|
||||
ax.freq = freq # type: ignore[attr-defined]
|
||||
xaxis = ax.get_xaxis()
|
||||
# TODO #54485
|
||||
xaxis.freq = freq # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _get_ax_freq(ax: Axes):
|
||||
"""
|
||||
Get the freq attribute of the ax object if set.
|
||||
Also checks shared axes (eg when using secondary yaxis, sharex=True
|
||||
or twinx)
|
||||
"""
|
||||
ax_freq = getattr(ax, "freq", None)
|
||||
if ax_freq is None:
|
||||
# check for left/right ax in case of secondary yaxis
|
||||
if hasattr(ax, "left_ax"):
|
||||
ax_freq = getattr(ax.left_ax, "freq", None)
|
||||
elif hasattr(ax, "right_ax"):
|
||||
ax_freq = getattr(ax.right_ax, "freq", None)
|
||||
if ax_freq is None:
|
||||
# check if a shared ax (sharex/twinx) has already freq set
|
||||
shared_axes = ax.get_shared_x_axes().get_siblings(ax)
|
||||
if len(shared_axes) > 1:
|
||||
for shared_ax in shared_axes:
|
||||
ax_freq = getattr(shared_ax, "freq", None)
|
||||
if ax_freq is not None:
|
||||
break
|
||||
return ax_freq
|
||||
|
||||
|
||||
def _get_period_alias(freq: timedelta | BaseOffset | str) -> str | None:
|
||||
if isinstance(freq, BaseOffset):
|
||||
freqstr = freq.name
|
||||
else:
|
||||
freqstr = to_offset(freq, is_period=True).rule_code
|
||||
|
||||
return get_period_alias(freqstr)
|
||||
|
||||
|
||||
def _get_freq(ax: Axes, series: Series):
|
||||
# get frequency from data
|
||||
freq = getattr(series.index, "freq", None)
|
||||
if freq is None:
|
||||
freq = getattr(series.index, "inferred_freq", None)
|
||||
freq = to_offset(freq, is_period=True)
|
||||
|
||||
ax_freq = _get_ax_freq(ax)
|
||||
|
||||
# use axes freq if no data freq
|
||||
if freq is None:
|
||||
freq = ax_freq
|
||||
|
||||
# get the period frequency
|
||||
freq = _get_period_alias(freq)
|
||||
return freq, ax_freq
|
||||
|
||||
|
||||
def use_dynamic_x(ax: Axes, data: DataFrame | Series) -> bool:
|
||||
freq = _get_index_freq(data.index)
|
||||
ax_freq = _get_ax_freq(ax)
|
||||
|
||||
if freq is None: # convert irregular if axes has freq info
|
||||
freq = ax_freq
|
||||
# do not use tsplot if irregular was plotted first
|
||||
elif (ax_freq is None) and (len(ax.get_lines()) > 0):
|
||||
return False
|
||||
|
||||
if freq is None:
|
||||
return False
|
||||
|
||||
freq_str = _get_period_alias(freq)
|
||||
|
||||
if freq_str is None:
|
||||
return False
|
||||
|
||||
# FIXME: hack this for 0.10.1, creating more technical debt...sigh
|
||||
if isinstance(data.index, ABCDatetimeIndex):
|
||||
# error: "BaseOffset" has no attribute "_period_dtype_code"
|
||||
freq_str = OFFSET_TO_PERIOD_FREQSTR.get(freq_str, freq_str)
|
||||
base = to_offset(
|
||||
freq_str, is_period=True
|
||||
)._period_dtype_code # type: ignore[attr-defined]
|
||||
x = data.index
|
||||
if base <= FreqGroup.FR_DAY.value:
|
||||
return x[:1].is_normalized
|
||||
period = Period(x[0], freq_str)
|
||||
assert isinstance(period, Period)
|
||||
return period.to_timestamp().tz_localize(x.tz) == x[0]
|
||||
return True
|
||||
|
||||
|
||||
def _get_index_freq(index: Index) -> BaseOffset | None:
|
||||
freq = getattr(index, "freq", None)
|
||||
if freq is None:
|
||||
freq = getattr(index, "inferred_freq", None)
|
||||
if freq == "B":
|
||||
# error: "Index" has no attribute "dayofweek"
|
||||
weekdays = np.unique(index.dayofweek) # type: ignore[attr-defined]
|
||||
if (5 in weekdays) or (6 in weekdays):
|
||||
freq = None
|
||||
|
||||
freq = to_offset(freq)
|
||||
return freq
|
||||
|
||||
|
||||
def maybe_convert_index(ax: Axes, data: NDFrameT) -> NDFrameT:
|
||||
# tsplot converts automatically, but don't want to convert index
|
||||
# over and over for DataFrames
|
||||
if isinstance(data.index, (ABCDatetimeIndex, ABCPeriodIndex)):
|
||||
freq: str | BaseOffset | None = data.index.freq
|
||||
|
||||
if freq is None:
|
||||
# We only get here for DatetimeIndex
|
||||
data.index = cast("DatetimeIndex", data.index)
|
||||
freq = data.index.inferred_freq
|
||||
freq = to_offset(freq)
|
||||
|
||||
if freq is None:
|
||||
freq = _get_ax_freq(ax)
|
||||
|
||||
if freq is None:
|
||||
raise ValueError("Could not get frequency alias for plotting")
|
||||
|
||||
freq_str = _get_period_alias(freq)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
# suppress Period[B] deprecation warning
|
||||
# TODO: need to find an alternative to this before the deprecation
|
||||
# is enforced!
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
r"PeriodDtype\[B\] is deprecated",
|
||||
category=FutureWarning,
|
||||
)
|
||||
|
||||
if isinstance(data.index, ABCDatetimeIndex):
|
||||
data = data.tz_localize(None).to_period(freq=freq_str)
|
||||
elif isinstance(data.index, ABCPeriodIndex):
|
||||
data.index = data.index.asfreq(freq=freq_str)
|
||||
return data
|
||||
|
||||
|
||||
# Patch methods for subplot.
|
||||
|
||||
|
||||
def _format_coord(freq, t, y) -> str:
|
||||
time_period = Period(ordinal=int(t), freq=freq)
|
||||
return f"t = {time_period} y = {y:8f}"
|
||||
|
||||
|
||||
def format_dateaxis(
|
||||
subplot, freq: BaseOffset, index: DatetimeIndex | PeriodIndex
|
||||
) -> None:
|
||||
"""
|
||||
Pretty-formats the date axis (x-axis).
|
||||
|
||||
Major and minor ticks are automatically set for the frequency of the
|
||||
current underlying series. As the dynamic mode is activated by
|
||||
default, changing the limits of the x axis will intelligently change
|
||||
the positions of the ticks.
|
||||
"""
|
||||
from matplotlib import pylab
|
||||
|
||||
# handle index specific formatting
|
||||
# Note: DatetimeIndex does not use this
|
||||
# interface. DatetimeIndex uses matplotlib.date directly
|
||||
if isinstance(index, ABCPeriodIndex):
|
||||
majlocator = TimeSeries_DateLocator(
|
||||
freq, dynamic_mode=True, minor_locator=False, plot_obj=subplot
|
||||
)
|
||||
minlocator = TimeSeries_DateLocator(
|
||||
freq, dynamic_mode=True, minor_locator=True, plot_obj=subplot
|
||||
)
|
||||
subplot.xaxis.set_major_locator(majlocator)
|
||||
subplot.xaxis.set_minor_locator(minlocator)
|
||||
|
||||
majformatter = TimeSeries_DateFormatter(
|
||||
freq, dynamic_mode=True, minor_locator=False, plot_obj=subplot
|
||||
)
|
||||
minformatter = TimeSeries_DateFormatter(
|
||||
freq, dynamic_mode=True, minor_locator=True, plot_obj=subplot
|
||||
)
|
||||
subplot.xaxis.set_major_formatter(majformatter)
|
||||
subplot.xaxis.set_minor_formatter(minformatter)
|
||||
|
||||
# x and y coord info
|
||||
subplot.format_coord = functools.partial(_format_coord, freq)
|
||||
|
||||
elif isinstance(index, ABCTimedeltaIndex):
|
||||
subplot.xaxis.set_major_formatter(TimeSeries_TimedeltaFormatter())
|
||||
else:
|
||||
raise TypeError("index type not supported")
|
||||
|
||||
pylab.draw_if_interactive()
|
@ -0,0 +1,492 @@
|
||||
# being a bit too dynamic
|
||||
from __future__ import annotations
|
||||
|
||||
from math import ceil
|
||||
from typing import TYPE_CHECKING
|
||||
import warnings
|
||||
|
||||
from matplotlib import ticker
|
||||
import matplotlib.table
|
||||
import numpy as np
|
||||
|
||||
from pandas.util._exceptions import find_stack_level
|
||||
|
||||
from pandas.core.dtypes.common import is_list_like
|
||||
from pandas.core.dtypes.generic import (
|
||||
ABCDataFrame,
|
||||
ABCIndex,
|
||||
ABCSeries,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import (
|
||||
Iterable,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.axis import Axis
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.lines import Line2D
|
||||
from matplotlib.table import Table
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
Series,
|
||||
)
|
||||
|
||||
|
||||
def do_adjust_figure(fig: Figure) -> bool:
|
||||
"""Whether fig has constrained_layout enabled."""
|
||||
if not hasattr(fig, "get_constrained_layout"):
|
||||
return False
|
||||
return not fig.get_constrained_layout()
|
||||
|
||||
|
||||
def maybe_adjust_figure(fig: Figure, *args, **kwargs) -> None:
|
||||
"""Call fig.subplots_adjust unless fig has constrained_layout enabled."""
|
||||
if do_adjust_figure(fig):
|
||||
fig.subplots_adjust(*args, **kwargs)
|
||||
|
||||
|
||||
def format_date_labels(ax: Axes, rot) -> None:
|
||||
# mini version of autofmt_xdate
|
||||
for label in ax.get_xticklabels():
|
||||
label.set_horizontalalignment("right")
|
||||
label.set_rotation(rot)
|
||||
fig = ax.get_figure()
|
||||
if fig is not None:
|
||||
# should always be a Figure but can technically be None
|
||||
maybe_adjust_figure(fig, bottom=0.2)
|
||||
|
||||
|
||||
def table(
|
||||
ax, data: DataFrame | Series, rowLabels=None, colLabels=None, **kwargs
|
||||
) -> Table:
|
||||
if isinstance(data, ABCSeries):
|
||||
data = data.to_frame()
|
||||
elif isinstance(data, ABCDataFrame):
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Input data must be DataFrame or Series")
|
||||
|
||||
if rowLabels is None:
|
||||
rowLabels = data.index
|
||||
|
||||
if colLabels is None:
|
||||
colLabels = data.columns
|
||||
|
||||
cellText = data.values
|
||||
|
||||
# error: Argument "cellText" to "table" has incompatible type "ndarray[Any,
|
||||
# Any]"; expected "Sequence[Sequence[str]] | None"
|
||||
return matplotlib.table.table(
|
||||
ax,
|
||||
cellText=cellText, # type: ignore[arg-type]
|
||||
rowLabels=rowLabels,
|
||||
colLabels=colLabels,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _get_layout(
|
||||
nplots: int,
|
||||
layout: tuple[int, int] | None = None,
|
||||
layout_type: str = "box",
|
||||
) -> tuple[int, int]:
|
||||
if layout is not None:
|
||||
if not isinstance(layout, (tuple, list)) or len(layout) != 2:
|
||||
raise ValueError("Layout must be a tuple of (rows, columns)")
|
||||
|
||||
nrows, ncols = layout
|
||||
|
||||
if nrows == -1 and ncols > 0:
|
||||
layout = nrows, ncols = (ceil(nplots / ncols), ncols)
|
||||
elif ncols == -1 and nrows > 0:
|
||||
layout = nrows, ncols = (nrows, ceil(nplots / nrows))
|
||||
elif ncols <= 0 and nrows <= 0:
|
||||
msg = "At least one dimension of layout must be positive"
|
||||
raise ValueError(msg)
|
||||
|
||||
if nrows * ncols < nplots:
|
||||
raise ValueError(
|
||||
f"Layout of {nrows}x{ncols} must be larger than required size {nplots}"
|
||||
)
|
||||
|
||||
return layout
|
||||
|
||||
if layout_type == "single":
|
||||
return (1, 1)
|
||||
elif layout_type == "horizontal":
|
||||
return (1, nplots)
|
||||
elif layout_type == "vertical":
|
||||
return (nplots, 1)
|
||||
|
||||
layouts = {1: (1, 1), 2: (1, 2), 3: (2, 2), 4: (2, 2)}
|
||||
try:
|
||||
return layouts[nplots]
|
||||
except KeyError:
|
||||
k = 1
|
||||
while k**2 < nplots:
|
||||
k += 1
|
||||
|
||||
if (k - 1) * k >= nplots:
|
||||
return k, (k - 1)
|
||||
else:
|
||||
return k, k
|
||||
|
||||
|
||||
# copied from matplotlib/pyplot.py and modified for pandas.plotting
|
||||
|
||||
|
||||
def create_subplots(
|
||||
naxes: int,
|
||||
sharex: bool = False,
|
||||
sharey: bool = False,
|
||||
squeeze: bool = True,
|
||||
subplot_kw=None,
|
||||
ax=None,
|
||||
layout=None,
|
||||
layout_type: str = "box",
|
||||
**fig_kw,
|
||||
):
|
||||
"""
|
||||
Create a figure with a set of subplots already made.
|
||||
|
||||
This utility wrapper makes it convenient to create common layouts of
|
||||
subplots, including the enclosing figure object, in a single call.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
naxes : int
|
||||
Number of required axes. Exceeded axes are set invisible. Default is
|
||||
nrows * ncols.
|
||||
|
||||
sharex : bool
|
||||
If True, the X axis will be shared amongst all subplots.
|
||||
|
||||
sharey : bool
|
||||
If True, the Y axis will be shared amongst all subplots.
|
||||
|
||||
squeeze : bool
|
||||
|
||||
If True, extra dimensions are squeezed out from the returned axis object:
|
||||
- if only one subplot is constructed (nrows=ncols=1), the resulting
|
||||
single Axis object is returned as a scalar.
|
||||
- for Nx1 or 1xN subplots, the returned object is a 1-d numpy object
|
||||
array of Axis objects are returned as numpy 1-d arrays.
|
||||
- for NxM subplots with N>1 and M>1 are returned as a 2d array.
|
||||
|
||||
If False, no squeezing is done: the returned axis object is always
|
||||
a 2-d array containing Axis instances, even if it ends up being 1x1.
|
||||
|
||||
subplot_kw : dict
|
||||
Dict with keywords passed to the add_subplot() call used to create each
|
||||
subplots.
|
||||
|
||||
ax : Matplotlib axis object, optional
|
||||
|
||||
layout : tuple
|
||||
Number of rows and columns of the subplot grid.
|
||||
If not specified, calculated from naxes and layout_type
|
||||
|
||||
layout_type : {'box', 'horizontal', 'vertical'}, default 'box'
|
||||
Specify how to layout the subplot grid.
|
||||
|
||||
fig_kw : Other keyword arguments to be passed to the figure() call.
|
||||
Note that all keywords not recognized above will be
|
||||
automatically included here.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig, ax : tuple
|
||||
- fig is the Matplotlib Figure object
|
||||
- ax can be either a single axis object or an array of axis objects if
|
||||
more than one subplot was created. The dimensions of the resulting array
|
||||
can be controlled with the squeeze keyword, see above.
|
||||
|
||||
Examples
|
||||
--------
|
||||
x = np.linspace(0, 2*np.pi, 400)
|
||||
y = np.sin(x**2)
|
||||
|
||||
# Just a figure and one subplot
|
||||
f, ax = plt.subplots()
|
||||
ax.plot(x, y)
|
||||
ax.set_title('Simple plot')
|
||||
|
||||
# Two subplots, unpack the output array immediately
|
||||
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
|
||||
ax1.plot(x, y)
|
||||
ax1.set_title('Sharing Y axis')
|
||||
ax2.scatter(x, y)
|
||||
|
||||
# Four polar axes
|
||||
plt.subplots(2, 2, subplot_kw=dict(polar=True))
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if subplot_kw is None:
|
||||
subplot_kw = {}
|
||||
|
||||
if ax is None:
|
||||
fig = plt.figure(**fig_kw)
|
||||
else:
|
||||
if is_list_like(ax):
|
||||
if squeeze:
|
||||
ax = flatten_axes(ax)
|
||||
if layout is not None:
|
||||
warnings.warn(
|
||||
"When passing multiple axes, layout keyword is ignored.",
|
||||
UserWarning,
|
||||
stacklevel=find_stack_level(),
|
||||
)
|
||||
if sharex or sharey:
|
||||
warnings.warn(
|
||||
"When passing multiple axes, sharex and sharey "
|
||||
"are ignored. These settings must be specified when creating axes.",
|
||||
UserWarning,
|
||||
stacklevel=find_stack_level(),
|
||||
)
|
||||
if ax.size == naxes:
|
||||
fig = ax.flat[0].get_figure()
|
||||
return fig, ax
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The number of passed axes must be {naxes}, the "
|
||||
"same as the output plot"
|
||||
)
|
||||
|
||||
fig = ax.get_figure()
|
||||
# if ax is passed and a number of subplots is 1, return ax as it is
|
||||
if naxes == 1:
|
||||
if squeeze:
|
||||
return fig, ax
|
||||
else:
|
||||
return fig, flatten_axes(ax)
|
||||
else:
|
||||
warnings.warn(
|
||||
"To output multiple subplots, the figure containing "
|
||||
"the passed axes is being cleared.",
|
||||
UserWarning,
|
||||
stacklevel=find_stack_level(),
|
||||
)
|
||||
fig.clear()
|
||||
|
||||
nrows, ncols = _get_layout(naxes, layout=layout, layout_type=layout_type)
|
||||
nplots = nrows * ncols
|
||||
|
||||
# Create empty object array to hold all axes. It's easiest to make it 1-d
|
||||
# so we can just append subplots upon creation, and then
|
||||
axarr = np.empty(nplots, dtype=object)
|
||||
|
||||
# Create first subplot separately, so we can share it if requested
|
||||
ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw)
|
||||
|
||||
if sharex:
|
||||
subplot_kw["sharex"] = ax0
|
||||
if sharey:
|
||||
subplot_kw["sharey"] = ax0
|
||||
axarr[0] = ax0
|
||||
|
||||
# Note off-by-one counting because add_subplot uses the MATLAB 1-based
|
||||
# convention.
|
||||
for i in range(1, nplots):
|
||||
kwds = subplot_kw.copy()
|
||||
# Set sharex and sharey to None for blank/dummy axes, these can
|
||||
# interfere with proper axis limits on the visible axes if
|
||||
# they share axes e.g. issue #7528
|
||||
if i >= naxes:
|
||||
kwds["sharex"] = None
|
||||
kwds["sharey"] = None
|
||||
ax = fig.add_subplot(nrows, ncols, i + 1, **kwds)
|
||||
axarr[i] = ax
|
||||
|
||||
if naxes != nplots:
|
||||
for ax in axarr[naxes:]:
|
||||
ax.set_visible(False)
|
||||
|
||||
handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey)
|
||||
|
||||
if squeeze:
|
||||
# Reshape the array to have the final desired dimension (nrow,ncol),
|
||||
# though discarding unneeded dimensions that equal 1. If we only have
|
||||
# one subplot, just return it instead of a 1-element array.
|
||||
if nplots == 1:
|
||||
axes = axarr[0]
|
||||
else:
|
||||
axes = axarr.reshape(nrows, ncols).squeeze()
|
||||
else:
|
||||
# returned axis array will be always 2-d, even if nrows=ncols=1
|
||||
axes = axarr.reshape(nrows, ncols)
|
||||
|
||||
return fig, axes
|
||||
|
||||
|
||||
def _remove_labels_from_axis(axis: Axis) -> None:
|
||||
for t in axis.get_majorticklabels():
|
||||
t.set_visible(False)
|
||||
|
||||
# set_visible will not be effective if
|
||||
# minor axis has NullLocator and NullFormatter (default)
|
||||
if isinstance(axis.get_minor_locator(), ticker.NullLocator):
|
||||
axis.set_minor_locator(ticker.AutoLocator())
|
||||
if isinstance(axis.get_minor_formatter(), ticker.NullFormatter):
|
||||
axis.set_minor_formatter(ticker.FormatStrFormatter(""))
|
||||
for t in axis.get_minorticklabels():
|
||||
t.set_visible(False)
|
||||
|
||||
axis.get_label().set_visible(False)
|
||||
|
||||
|
||||
def _has_externally_shared_axis(ax1: Axes, compare_axis: str) -> bool:
|
||||
"""
|
||||
Return whether an axis is externally shared.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax1 : matplotlib.axes.Axes
|
||||
Axis to query.
|
||||
compare_axis : str
|
||||
`"x"` or `"y"` according to whether the X-axis or Y-axis is being
|
||||
compared.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
`True` if the axis is externally shared. Otherwise `False`.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If two axes with different positions are sharing an axis, they can be
|
||||
referred to as *externally* sharing the common axis.
|
||||
|
||||
If two axes sharing an axis also have the same position, they can be
|
||||
referred to as *internally* sharing the common axis (a.k.a twinning).
|
||||
|
||||
_handle_shared_axes() is only interested in axes externally sharing an
|
||||
axis, regardless of whether either of the axes is also internally sharing
|
||||
with a third axis.
|
||||
"""
|
||||
if compare_axis == "x":
|
||||
axes = ax1.get_shared_x_axes()
|
||||
elif compare_axis == "y":
|
||||
axes = ax1.get_shared_y_axes()
|
||||
else:
|
||||
raise ValueError(
|
||||
"_has_externally_shared_axis() needs 'x' or 'y' as a second parameter"
|
||||
)
|
||||
|
||||
axes_siblings = axes.get_siblings(ax1)
|
||||
|
||||
# Retain ax1 and any of its siblings which aren't in the same position as it
|
||||
ax1_points = ax1.get_position().get_points()
|
||||
|
||||
for ax2 in axes_siblings:
|
||||
if not np.array_equal(ax1_points, ax2.get_position().get_points()):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def handle_shared_axes(
|
||||
axarr: Iterable[Axes],
|
||||
nplots: int,
|
||||
naxes: int,
|
||||
nrows: int,
|
||||
ncols: int,
|
||||
sharex: bool,
|
||||
sharey: bool,
|
||||
) -> None:
|
||||
if nplots > 1:
|
||||
row_num = lambda x: x.get_subplotspec().rowspan.start
|
||||
col_num = lambda x: x.get_subplotspec().colspan.start
|
||||
|
||||
is_first_col = lambda x: x.get_subplotspec().is_first_col()
|
||||
|
||||
if nrows > 1:
|
||||
try:
|
||||
# first find out the ax layout,
|
||||
# so that we can correctly handle 'gaps"
|
||||
layout = np.zeros((nrows + 1, ncols + 1), dtype=np.bool_)
|
||||
for ax in axarr:
|
||||
layout[row_num(ax), col_num(ax)] = ax.get_visible()
|
||||
|
||||
for ax in axarr:
|
||||
# only the last row of subplots should get x labels -> all
|
||||
# other off layout handles the case that the subplot is
|
||||
# the last in the column, because below is no subplot/gap.
|
||||
if not layout[row_num(ax) + 1, col_num(ax)]:
|
||||
continue
|
||||
if sharex or _has_externally_shared_axis(ax, "x"):
|
||||
_remove_labels_from_axis(ax.xaxis)
|
||||
|
||||
except IndexError:
|
||||
# if gridspec is used, ax.rowNum and ax.colNum may different
|
||||
# from layout shape. in this case, use last_row logic
|
||||
is_last_row = lambda x: x.get_subplotspec().is_last_row()
|
||||
for ax in axarr:
|
||||
if is_last_row(ax):
|
||||
continue
|
||||
if sharex or _has_externally_shared_axis(ax, "x"):
|
||||
_remove_labels_from_axis(ax.xaxis)
|
||||
|
||||
if ncols > 1:
|
||||
for ax in axarr:
|
||||
# only the first column should get y labels -> set all other to
|
||||
# off as we only have labels in the first column and we always
|
||||
# have a subplot there, we can skip the layout test
|
||||
if is_first_col(ax):
|
||||
continue
|
||||
if sharey or _has_externally_shared_axis(ax, "y"):
|
||||
_remove_labels_from_axis(ax.yaxis)
|
||||
|
||||
|
||||
def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray:
|
||||
if not is_list_like(axes):
|
||||
return np.array([axes])
|
||||
elif isinstance(axes, (np.ndarray, ABCIndex)):
|
||||
return np.asarray(axes).ravel()
|
||||
return np.array(axes)
|
||||
|
||||
|
||||
def set_ticks_props(
|
||||
axes: Axes | Sequence[Axes],
|
||||
xlabelsize: int | None = None,
|
||||
xrot=None,
|
||||
ylabelsize: int | None = None,
|
||||
yrot=None,
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
for ax in flatten_axes(axes):
|
||||
if xlabelsize is not None:
|
||||
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
|
||||
if xrot is not None:
|
||||
plt.setp(ax.get_xticklabels(), rotation=xrot)
|
||||
if ylabelsize is not None:
|
||||
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
|
||||
if yrot is not None:
|
||||
plt.setp(ax.get_yticklabels(), rotation=yrot)
|
||||
return axes
|
||||
|
||||
|
||||
def get_all_lines(ax: Axes) -> list[Line2D]:
|
||||
lines = ax.get_lines()
|
||||
|
||||
if hasattr(ax, "right_ax"):
|
||||
lines += ax.right_ax.get_lines()
|
||||
|
||||
if hasattr(ax, "left_ax"):
|
||||
lines += ax.left_ax.get_lines()
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def get_xlim(lines: Iterable[Line2D]) -> tuple[float, float]:
|
||||
left, right = np.inf, -np.inf
|
||||
for line in lines:
|
||||
x = line.get_xdata(orig=False)
|
||||
left = min(np.nanmin(x), left)
|
||||
right = max(np.nanmax(x), right)
|
||||
return left, right
|
688
lib/python3.13/site-packages/pandas/plotting/_misc.py
Normal file
688
lib/python3.13/site-packages/pandas/plotting/_misc.py
Normal file
@ -0,0 +1,688 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from pandas.plotting._core import _get_plot_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import (
|
||||
Generator,
|
||||
Mapping,
|
||||
)
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.colors import Colormap
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.table import Table
|
||||
import numpy as np
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
Series,
|
||||
)
|
||||
|
||||
|
||||
def table(ax: Axes, data: DataFrame | Series, **kwargs) -> Table:
|
||||
"""
|
||||
Helper function to convert DataFrame and Series to matplotlib.table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : Matplotlib axes object
|
||||
data : DataFrame or Series
|
||||
Data for table contents.
|
||||
**kwargs
|
||||
Keyword arguments to be passed to matplotlib.table.table.
|
||||
If `rowLabels` or `colLabels` is not specified, data index or column
|
||||
name will be used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib table object
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
|
||||
>>> fix, ax = plt.subplots()
|
||||
>>> ax.axis('off')
|
||||
(0.0, 1.0, 0.0, 1.0)
|
||||
>>> table = pd.plotting.table(ax, df, loc='center',
|
||||
... cellLoc='center', colWidths=list([.2, .2]))
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
return plot_backend.table(
|
||||
ax=ax, data=data, rowLabels=None, colLabels=None, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def register() -> None:
|
||||
"""
|
||||
Register pandas formatters and converters with matplotlib.
|
||||
|
||||
This function modifies the global ``matplotlib.units.registry``
|
||||
dictionary. pandas adds custom converters for
|
||||
|
||||
* pd.Timestamp
|
||||
* pd.Period
|
||||
* np.datetime64
|
||||
* datetime.datetime
|
||||
* datetime.date
|
||||
* datetime.time
|
||||
|
||||
See Also
|
||||
--------
|
||||
deregister_matplotlib_converters : Remove pandas formatters and converters.
|
||||
|
||||
Examples
|
||||
--------
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
The following line is done automatically by pandas so
|
||||
the plot can be rendered:
|
||||
|
||||
>>> pd.plotting.register_matplotlib_converters()
|
||||
|
||||
>>> df = pd.DataFrame({'ts': pd.period_range('2020', periods=2, freq='M'),
|
||||
... 'y': [1, 2]
|
||||
... })
|
||||
>>> plot = df.plot.line(x='ts', y='y')
|
||||
|
||||
Unsetting the register manually an error will be raised:
|
||||
|
||||
>>> pd.set_option("plotting.matplotlib.register_converters",
|
||||
... False) # doctest: +SKIP
|
||||
>>> df.plot.line(x='ts', y='y') # doctest: +SKIP
|
||||
Traceback (most recent call last):
|
||||
TypeError: float() argument must be a string or a real number, not 'Period'
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
plot_backend.register()
|
||||
|
||||
|
||||
def deregister() -> None:
|
||||
"""
|
||||
Remove pandas formatters and converters.
|
||||
|
||||
Removes the custom converters added by :func:`register`. This
|
||||
attempts to set the state of the registry back to the state before
|
||||
pandas registered its own units. Converters for pandas' own types like
|
||||
Timestamp and Period are removed completely. Converters for types
|
||||
pandas overwrites, like ``datetime.datetime``, are restored to their
|
||||
original value.
|
||||
|
||||
See Also
|
||||
--------
|
||||
register_matplotlib_converters : Register pandas formatters and converters
|
||||
with matplotlib.
|
||||
|
||||
Examples
|
||||
--------
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
The following line is done automatically by pandas so
|
||||
the plot can be rendered:
|
||||
|
||||
>>> pd.plotting.register_matplotlib_converters()
|
||||
|
||||
>>> df = pd.DataFrame({'ts': pd.period_range('2020', periods=2, freq='M'),
|
||||
... 'y': [1, 2]
|
||||
... })
|
||||
>>> plot = df.plot.line(x='ts', y='y')
|
||||
|
||||
Unsetting the register manually an error will be raised:
|
||||
|
||||
>>> pd.set_option("plotting.matplotlib.register_converters",
|
||||
... False) # doctest: +SKIP
|
||||
>>> df.plot.line(x='ts', y='y') # doctest: +SKIP
|
||||
Traceback (most recent call last):
|
||||
TypeError: float() argument must be a string or a real number, not 'Period'
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
plot_backend.deregister()
|
||||
|
||||
|
||||
def scatter_matrix(
|
||||
frame: DataFrame,
|
||||
alpha: float = 0.5,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
ax: Axes | None = None,
|
||||
grid: bool = False,
|
||||
diagonal: str = "hist",
|
||||
marker: str = ".",
|
||||
density_kwds: Mapping[str, Any] | None = None,
|
||||
hist_kwds: Mapping[str, Any] | None = None,
|
||||
range_padding: float = 0.05,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Draw a matrix of scatter plots.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frame : DataFrame
|
||||
alpha : float, optional
|
||||
Amount of transparency applied.
|
||||
figsize : (float,float), optional
|
||||
A tuple (width, height) in inches.
|
||||
ax : Matplotlib axis object, optional
|
||||
grid : bool, optional
|
||||
Setting this to True will show the grid.
|
||||
diagonal : {'hist', 'kde'}
|
||||
Pick between 'kde' and 'hist' for either Kernel Density Estimation or
|
||||
Histogram plot in the diagonal.
|
||||
marker : str, optional
|
||||
Matplotlib marker type, default '.'.
|
||||
density_kwds : keywords
|
||||
Keyword arguments to be passed to kernel density estimate plot.
|
||||
hist_kwds : keywords
|
||||
Keyword arguments to be passed to hist function.
|
||||
range_padding : float, default 0.05
|
||||
Relative extension of axis range in x and y with respect to
|
||||
(x_max - x_min) or (y_max - y_min).
|
||||
**kwargs
|
||||
Keyword arguments to be passed to scatter function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
numpy.ndarray
|
||||
A matrix of scatter plots.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> df = pd.DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
|
||||
>>> pd.plotting.scatter_matrix(df, alpha=0.2)
|
||||
array([[<Axes: xlabel='A', ylabel='A'>, <Axes: xlabel='B', ylabel='A'>,
|
||||
<Axes: xlabel='C', ylabel='A'>, <Axes: xlabel='D', ylabel='A'>],
|
||||
[<Axes: xlabel='A', ylabel='B'>, <Axes: xlabel='B', ylabel='B'>,
|
||||
<Axes: xlabel='C', ylabel='B'>, <Axes: xlabel='D', ylabel='B'>],
|
||||
[<Axes: xlabel='A', ylabel='C'>, <Axes: xlabel='B', ylabel='C'>,
|
||||
<Axes: xlabel='C', ylabel='C'>, <Axes: xlabel='D', ylabel='C'>],
|
||||
[<Axes: xlabel='A', ylabel='D'>, <Axes: xlabel='B', ylabel='D'>,
|
||||
<Axes: xlabel='C', ylabel='D'>, <Axes: xlabel='D', ylabel='D'>]],
|
||||
dtype=object)
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
return plot_backend.scatter_matrix(
|
||||
frame=frame,
|
||||
alpha=alpha,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
grid=grid,
|
||||
diagonal=diagonal,
|
||||
marker=marker,
|
||||
density_kwds=density_kwds,
|
||||
hist_kwds=hist_kwds,
|
||||
range_padding=range_padding,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def radviz(
|
||||
frame: DataFrame,
|
||||
class_column: str,
|
||||
ax: Axes | None = None,
|
||||
color: list[str] | tuple[str, ...] | None = None,
|
||||
colormap: Colormap | str | None = None,
|
||||
**kwds,
|
||||
) -> Axes:
|
||||
"""
|
||||
Plot a multidimensional dataset in 2D.
|
||||
|
||||
Each Series in the DataFrame is represented as a evenly distributed
|
||||
slice on a circle. Each data point is rendered in the circle according to
|
||||
the value on each Series. Highly correlated `Series` in the `DataFrame`
|
||||
are placed closer on the unit circle.
|
||||
|
||||
RadViz allow to project a N-dimensional data set into a 2D space where the
|
||||
influence of each dimension can be interpreted as a balance between the
|
||||
influence of all dimensions.
|
||||
|
||||
More info available at the `original article
|
||||
<https://doi.org/10.1145/331770.331775>`_
|
||||
describing RadViz.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frame : `DataFrame`
|
||||
Object holding the data.
|
||||
class_column : str
|
||||
Column name containing the name of the data point category.
|
||||
ax : :class:`matplotlib.axes.Axes`, optional
|
||||
A plot instance to which to add the information.
|
||||
color : list[str] or tuple[str], optional
|
||||
Assign a color to each category. Example: ['blue', 'green'].
|
||||
colormap : str or :class:`matplotlib.colors.Colormap`, default None
|
||||
Colormap to select colors from. If string, load colormap with that
|
||||
name from matplotlib.
|
||||
**kwds
|
||||
Options to pass to matplotlib scatter plotting method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`matplotlib.axes.Axes`
|
||||
|
||||
See Also
|
||||
--------
|
||||
pandas.plotting.andrews_curves : Plot clustering visualization.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> df = pd.DataFrame(
|
||||
... {
|
||||
... 'SepalLength': [6.5, 7.7, 5.1, 5.8, 7.6, 5.0, 5.4, 4.6, 6.7, 4.6],
|
||||
... 'SepalWidth': [3.0, 3.8, 3.8, 2.7, 3.0, 2.3, 3.0, 3.2, 3.3, 3.6],
|
||||
... 'PetalLength': [5.5, 6.7, 1.9, 5.1, 6.6, 3.3, 4.5, 1.4, 5.7, 1.0],
|
||||
... 'PetalWidth': [1.8, 2.2, 0.4, 1.9, 2.1, 1.0, 1.5, 0.2, 2.1, 0.2],
|
||||
... 'Category': [
|
||||
... 'virginica',
|
||||
... 'virginica',
|
||||
... 'setosa',
|
||||
... 'virginica',
|
||||
... 'virginica',
|
||||
... 'versicolor',
|
||||
... 'versicolor',
|
||||
... 'setosa',
|
||||
... 'virginica',
|
||||
... 'setosa'
|
||||
... ]
|
||||
... }
|
||||
... )
|
||||
>>> pd.plotting.radviz(df, 'Category') # doctest: +SKIP
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
return plot_backend.radviz(
|
||||
frame=frame,
|
||||
class_column=class_column,
|
||||
ax=ax,
|
||||
color=color,
|
||||
colormap=colormap,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
|
||||
def andrews_curves(
|
||||
frame: DataFrame,
|
||||
class_column: str,
|
||||
ax: Axes | None = None,
|
||||
samples: int = 200,
|
||||
color: list[str] | tuple[str, ...] | None = None,
|
||||
colormap: Colormap | str | None = None,
|
||||
**kwargs,
|
||||
) -> Axes:
|
||||
"""
|
||||
Generate a matplotlib plot for visualizing clusters of multivariate data.
|
||||
|
||||
Andrews curves have the functional form:
|
||||
|
||||
.. math::
|
||||
f(t) = \\frac{x_1}{\\sqrt{2}} + x_2 \\sin(t) + x_3 \\cos(t) +
|
||||
x_4 \\sin(2t) + x_5 \\cos(2t) + \\cdots
|
||||
|
||||
Where :math:`x` coefficients correspond to the values of each dimension
|
||||
and :math:`t` is linearly spaced between :math:`-\\pi` and :math:`+\\pi`.
|
||||
Each row of frame then corresponds to a single curve.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frame : DataFrame
|
||||
Data to be plotted, preferably normalized to (0.0, 1.0).
|
||||
class_column : label
|
||||
Name of the column containing class names.
|
||||
ax : axes object, default None
|
||||
Axes to use.
|
||||
samples : int
|
||||
Number of points to plot in each curve.
|
||||
color : str, list[str] or tuple[str], optional
|
||||
Colors to use for the different classes. Colors can be strings
|
||||
or 3-element floating point RGB values.
|
||||
colormap : str or matplotlib colormap object, default None
|
||||
Colormap to select colors from. If a string, load colormap with that
|
||||
name from matplotlib.
|
||||
**kwargs
|
||||
Options to pass to matplotlib plotting method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`matplotlib.axes.Axes`
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> df = pd.read_csv(
|
||||
... 'https://raw.githubusercontent.com/pandas-dev/'
|
||||
... 'pandas/main/pandas/tests/io/data/csv/iris.csv'
|
||||
... )
|
||||
>>> pd.plotting.andrews_curves(df, 'Name') # doctest: +SKIP
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
return plot_backend.andrews_curves(
|
||||
frame=frame,
|
||||
class_column=class_column,
|
||||
ax=ax,
|
||||
samples=samples,
|
||||
color=color,
|
||||
colormap=colormap,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def bootstrap_plot(
|
||||
series: Series,
|
||||
fig: Figure | None = None,
|
||||
size: int = 50,
|
||||
samples: int = 500,
|
||||
**kwds,
|
||||
) -> Figure:
|
||||
"""
|
||||
Bootstrap plot on mean, median and mid-range statistics.
|
||||
|
||||
The bootstrap plot is used to estimate the uncertainty of a statistic
|
||||
by relying on random sampling with replacement [1]_. This function will
|
||||
generate bootstrapping plots for mean, median and mid-range statistics
|
||||
for the given number of samples of the given size.
|
||||
|
||||
.. [1] "Bootstrapping (statistics)" in \
|
||||
https://en.wikipedia.org/wiki/Bootstrapping_%28statistics%29
|
||||
|
||||
Parameters
|
||||
----------
|
||||
series : pandas.Series
|
||||
Series from where to get the samplings for the bootstrapping.
|
||||
fig : matplotlib.figure.Figure, default None
|
||||
If given, it will use the `fig` reference for plotting instead of
|
||||
creating a new one with default parameters.
|
||||
size : int, default 50
|
||||
Number of data points to consider during each sampling. It must be
|
||||
less than or equal to the length of the `series`.
|
||||
samples : int, default 500
|
||||
Number of times the bootstrap procedure is performed.
|
||||
**kwds
|
||||
Options to pass to matplotlib plotting method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.figure.Figure
|
||||
Matplotlib figure.
|
||||
|
||||
See Also
|
||||
--------
|
||||
pandas.DataFrame.plot : Basic plotting for DataFrame objects.
|
||||
pandas.Series.plot : Basic plotting for Series objects.
|
||||
|
||||
Examples
|
||||
--------
|
||||
This example draws a basic bootstrap plot for a Series.
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> s = pd.Series(np.random.uniform(size=100))
|
||||
>>> pd.plotting.bootstrap_plot(s) # doctest: +SKIP
|
||||
<Figure size 640x480 with 6 Axes>
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
return plot_backend.bootstrap_plot(
|
||||
series=series, fig=fig, size=size, samples=samples, **kwds
|
||||
)
|
||||
|
||||
|
||||
def parallel_coordinates(
|
||||
frame: DataFrame,
|
||||
class_column: str,
|
||||
cols: list[str] | None = None,
|
||||
ax: Axes | None = None,
|
||||
color: list[str] | tuple[str, ...] | None = None,
|
||||
use_columns: bool = False,
|
||||
xticks: list | tuple | None = None,
|
||||
colormap: Colormap | str | None = None,
|
||||
axvlines: bool = True,
|
||||
axvlines_kwds: Mapping[str, Any] | None = None,
|
||||
sort_labels: bool = False,
|
||||
**kwargs,
|
||||
) -> Axes:
|
||||
"""
|
||||
Parallel coordinates plotting.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frame : DataFrame
|
||||
class_column : str
|
||||
Column name containing class names.
|
||||
cols : list, optional
|
||||
A list of column names to use.
|
||||
ax : matplotlib.axis, optional
|
||||
Matplotlib axis object.
|
||||
color : list or tuple, optional
|
||||
Colors to use for the different classes.
|
||||
use_columns : bool, optional
|
||||
If true, columns will be used as xticks.
|
||||
xticks : list or tuple, optional
|
||||
A list of values to use for xticks.
|
||||
colormap : str or matplotlib colormap, default None
|
||||
Colormap to use for line colors.
|
||||
axvlines : bool, optional
|
||||
If true, vertical lines will be added at each xtick.
|
||||
axvlines_kwds : keywords, optional
|
||||
Options to be passed to axvline method for vertical lines.
|
||||
sort_labels : bool, default False
|
||||
Sort class_column labels, useful when assigning colors.
|
||||
**kwargs
|
||||
Options to pass to matplotlib plotting method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.axes.Axes
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> df = pd.read_csv(
|
||||
... 'https://raw.githubusercontent.com/pandas-dev/'
|
||||
... 'pandas/main/pandas/tests/io/data/csv/iris.csv'
|
||||
... )
|
||||
>>> pd.plotting.parallel_coordinates(
|
||||
... df, 'Name', color=('#556270', '#4ECDC4', '#C7F464')
|
||||
... ) # doctest: +SKIP
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
return plot_backend.parallel_coordinates(
|
||||
frame=frame,
|
||||
class_column=class_column,
|
||||
cols=cols,
|
||||
ax=ax,
|
||||
color=color,
|
||||
use_columns=use_columns,
|
||||
xticks=xticks,
|
||||
colormap=colormap,
|
||||
axvlines=axvlines,
|
||||
axvlines_kwds=axvlines_kwds,
|
||||
sort_labels=sort_labels,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def lag_plot(series: Series, lag: int = 1, ax: Axes | None = None, **kwds) -> Axes:
|
||||
"""
|
||||
Lag plot for time series.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
series : Series
|
||||
The time series to visualize.
|
||||
lag : int, default 1
|
||||
Lag length of the scatter plot.
|
||||
ax : Matplotlib axis object, optional
|
||||
The matplotlib axis object to use.
|
||||
**kwds
|
||||
Matplotlib scatter method keyword arguments.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.axes.Axes
|
||||
|
||||
Examples
|
||||
--------
|
||||
Lag plots are most commonly used to look for patterns in time series data.
|
||||
|
||||
Given the following time series
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> np.random.seed(5)
|
||||
>>> x = np.cumsum(np.random.normal(loc=1, scale=5, size=50))
|
||||
>>> s = pd.Series(x)
|
||||
>>> s.plot() # doctest: +SKIP
|
||||
|
||||
A lag plot with ``lag=1`` returns
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> pd.plotting.lag_plot(s, lag=1)
|
||||
<Axes: xlabel='y(t)', ylabel='y(t + 1)'>
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
return plot_backend.lag_plot(series=series, lag=lag, ax=ax, **kwds)
|
||||
|
||||
|
||||
def autocorrelation_plot(series: Series, ax: Axes | None = None, **kwargs) -> Axes:
|
||||
"""
|
||||
Autocorrelation plot for time series.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
series : Series
|
||||
The time series to visualize.
|
||||
ax : Matplotlib axis object, optional
|
||||
The matplotlib axis object to use.
|
||||
**kwargs
|
||||
Options to pass to matplotlib plotting method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.axes.Axes
|
||||
|
||||
Examples
|
||||
--------
|
||||
The horizontal lines in the plot correspond to 95% and 99% confidence bands.
|
||||
|
||||
The dashed line is 99% confidence band.
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> spacing = np.linspace(-9 * np.pi, 9 * np.pi, num=1000)
|
||||
>>> s = pd.Series(0.7 * np.random.rand(1000) + 0.3 * np.sin(spacing))
|
||||
>>> pd.plotting.autocorrelation_plot(s) # doctest: +SKIP
|
||||
"""
|
||||
plot_backend = _get_plot_backend("matplotlib")
|
||||
return plot_backend.autocorrelation_plot(series=series, ax=ax, **kwargs)
|
||||
|
||||
|
||||
class _Options(dict):
|
||||
"""
|
||||
Stores pandas plotting options.
|
||||
|
||||
Allows for parameter aliasing so you can just use parameter names that are
|
||||
the same as the plot function parameters, but is stored in a canonical
|
||||
format that makes it easy to breakdown into groups later.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
|
||||
>>> np.random.seed(42)
|
||||
>>> df = pd.DataFrame({'A': np.random.randn(10),
|
||||
... 'B': np.random.randn(10)},
|
||||
... index=pd.date_range("1/1/2000",
|
||||
... freq='4MS', periods=10))
|
||||
>>> with pd.plotting.plot_params.use("x_compat", True):
|
||||
... _ = df["A"].plot(color="r")
|
||||
... _ = df["B"].plot(color="g")
|
||||
"""
|
||||
|
||||
# alias so the names are same as plotting method parameter names
|
||||
_ALIASES = {"x_compat": "xaxis.compat"}
|
||||
_DEFAULT_KEYS = ["xaxis.compat"]
|
||||
|
||||
def __init__(self, deprecated: bool = False) -> None:
|
||||
self._deprecated = deprecated
|
||||
super().__setitem__("xaxis.compat", False)
|
||||
|
||||
def __getitem__(self, key):
|
||||
key = self._get_canonical_key(key)
|
||||
if key not in self:
|
||||
raise ValueError(f"{key} is not a valid pandas plotting option")
|
||||
return super().__getitem__(key)
|
||||
|
||||
def __setitem__(self, key, value) -> None:
|
||||
key = self._get_canonical_key(key)
|
||||
super().__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key) -> None:
|
||||
key = self._get_canonical_key(key)
|
||||
if key in self._DEFAULT_KEYS:
|
||||
raise ValueError(f"Cannot remove default parameter {key}")
|
||||
super().__delitem__(key)
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
key = self._get_canonical_key(key)
|
||||
return super().__contains__(key)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the option store to its initial state
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
# error: Cannot access "__init__" directly
|
||||
self.__init__() # type: ignore[misc]
|
||||
|
||||
def _get_canonical_key(self, key):
|
||||
return self._ALIASES.get(key, key)
|
||||
|
||||
@contextmanager
|
||||
def use(self, key, value) -> Generator[_Options, None, None]:
|
||||
"""
|
||||
Temporarily set a parameter value using the with statement.
|
||||
Aliasing allowed.
|
||||
"""
|
||||
old_value = self[key]
|
||||
try:
|
||||
self[key] = value
|
||||
yield self
|
||||
finally:
|
||||
self[key] = old_value
|
||||
|
||||
|
||||
plot_params = _Options()
|
Reference in New Issue
Block a user