diff --git a/src/mplfinance/_mplwraps.py b/src/mplfinance/_mplwraps.py index 93ff4ec7..bad47fbb 100644 --- a/src/mplfinance/_mplwraps.py +++ b/src/mplfinance/_mplwraps.py @@ -1,6 +1,8 @@ import matplotlib.pyplot as plt import matplotlib.figure as mplfigure +import matplotlib.axes as mpl_axes from mplfinance import _styles +import numpy as np """ This file contains: @@ -102,12 +104,18 @@ def subplots(self,*args,**kwargs): if 'style' in kwargs or not hasattr(self,'mpfstyle'): style = _check_for_and_apply_style(kwargs) + self.mpfstyle = style else: style = _check_for_and_apply_style(dict(style=self.mpfstyle)) axlist = mplfigure.Figure.subplots(self,*args,**kwargs) - - self.mpfstyle = style - if ax in axlist: - ax.mpfstyle = style - return fig, axlist + + if isinstance(axlist,mpl_axes.Axes): + axlist.mpfstyle = style + elif isinstance(axlist,np.ndarray): + for ax in axlist.flatten(): + ax.mpfstyle = style + else: + raise TypeError('Unexpected type ('+str(type(axlist))+') '+ + 'returned from "matplotlib.figure.Figure.subplots()"') + return axlist diff --git a/src/mplfinance/_version.py b/src/mplfinance/_version.py index 109f75a8..89621582 100644 --- a/src/mplfinance/_version.py +++ b/src/mplfinance/_version.py @@ -1,5 +1,5 @@ -version_info = (0, 12, 7, 'alpha', 1) +version_info = (0, 12, 7, 'alpha', 2) _specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}