Skip to content

Script additions for 'ema' plotting #563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 27, 2022
2 changes: 1 addition & 1 deletion src/mplfinance/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version_info = (0, 12, 9, 'beta', 2)
version_info = (0, 12, 9, 'beta', 3)

_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}

Expand Down
67 changes: 63 additions & 4 deletions src/mplfinance/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,17 @@ def _valid_plot_kwargs():
'mav' : { 'Default' : None,
'Description' : 'Moving Average window size(s); (int or tuple of ints)',
'Validator' : _mav_validator },

'ema' : { 'Default' : None,
'Description' : 'Exponential Moving Average window size(s); (int or tuple of ints)',
'Validator' : _mav_validator },

'mavcolors' : { 'Default' : None,
'Description' : 'color cycle for moving averages (list or tuple of colors)'+
'(overrides mpf style mavcolors).',
'Validator' : lambda value: isinstance(value,(list,tuple)) and
all([mcolors.is_color_like(v) for v in value]) },

'renko_params' : { 'Default' : dict(),
'Description' : 'dict of renko parameters; call `mpf.kwarg_help("renko_params")`',
'Validator' : lambda value: isinstance(value,dict) },
Expand Down Expand Up @@ -450,6 +460,13 @@ def plot( data, **kwargs ):
else:
raise TypeError('style should be a `dict`; why is it not?')

if config['mavcolors'] is not None:
config['_ma_color_cycle'] = cycle(config['mavcolors'])
elif style['mavcolors'] is not None:
config['_ma_color_cycle'] = cycle(style['mavcolors'])
else:
config['_ma_color_cycle'] = None

if not external_axes_mode:
fig = plt.figure()
_adjust_figsize(fig,config)
Expand Down Expand Up @@ -528,8 +545,10 @@ def plot( data, **kwargs ):

if ptype in VALID_PMOVE_TYPES:
mavprices = _plot_mav(axA1,config,xdates,pmove_avgvals)
emaprices = _plot_ema(axA1, config, xdates, pmove_avgvals)
else:
mavprices = _plot_mav(axA1,config,xdates,closes)
emaprices = _plot_ema(axA1, config, xdates, closes)

avg_dist_between_points = (xdates[-1] - xdates[0]) / float(len(xdates))
if not config['tight_layout']:
Expand Down Expand Up @@ -595,6 +614,13 @@ def plot( data, **kwargs ):
else:
for jj in range(0,len(mav)):
retdict['mav' + str(mav[jj])] = mavprices[jj]
if config['ema'] is not None:
ema = config['ema']
if len(ema) != len(emaprices):
warnings.warn('len(ema)='+str(len(ema))+' BUT len(emaprices)='+str(len(emaprices)))
else:
for jj in range(0, len(ema)):
retdict['ema' + str(ema[jj])] = emaprices[jj]
retdict['minx'] = minx
retdict['maxx'] = maxx
retdict['miny'] = miny
Expand Down Expand Up @@ -1129,10 +1155,7 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
if len(mavgs) > 7:
mavgs = mavgs[0:7] # take at most 7

if style['mavcolors'] is not None:
mavc = cycle(style['mavcolors'])
else:
mavc = None
mavc = config['_ma_color_cycle']

for idx,mav in enumerate(mavgs):
mean = pd.Series(prices).rolling(mav).mean()
Expand All @@ -1147,6 +1170,42 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
mavp_list.append(mavprices)
return mavp_list


def _plot_ema(ax,config,xdates,prices,apmav=None,apwidth=None):
'''ema: exponential moving average'''
style = config['style']
if apmav is not None:
mavgs = apmav
else:
mavgs = config['ema']
mavp_list = []
if mavgs is not None:
shift = None
if isinstance(mavgs,dict):
shift = mavgs['shift']
mavgs = mavgs['period']
if isinstance(mavgs,int):
mavgs = mavgs, # convert to tuple
if len(mavgs) > 7:
mavgs = mavgs[0:7] # take at most 7

mavc = config['_ma_color_cycle']

for idx,mav in enumerate(mavgs):
# mean = pd.Series(prices).rolling(mav).mean()
mean = pd.Series(prices).ewm(span=mav,adjust=False).mean()
if shift is not None:
mean = mean.shift(periods=shift[idx])
emaprices = mean.values
lw = config['_width_config']['line_width']
if mavc:
ax.plot(xdates, emaprices, linewidth=lw, color=next(mavc))
else:
ax.plot(xdates, emaprices, linewidth=lw)
mavp_list.append(emaprices)
return mavp_list


def _auto_secondary_y( panels, panid, ylo, yhi ):
# If mag(nitude) for this panel is not yet set, then set it
# here, as this is the first ydata to be plotted on this panel:
Expand Down
Binary file added tests/reference_images/ema01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/reference_images/ema02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/reference_images/ema03.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
124 changes: 124 additions & 0 deletions tests/test_ema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import os
import os.path
import glob
import mplfinance as mpf
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.testing.compare import compare_images

print('mpf.__version__ =',mpf.__version__) # for the record
print('mpf.__file__ =',mpf.__file__) # for the record
print("plt.rcParams['backend'] =",plt.rcParams['backend']) # for the record

base='ema'
tdir = os.path.join('tests','test_images')
refd = os.path.join('tests','reference_images')

globpattern = os.path.join(tdir,base+'*.png')
oldtestfiles = glob.glob(globpattern)
for fn in oldtestfiles:
try:
os.remove(fn)
except:
print('Error removing file "'+fn+'"')

IMGCOMP_TOLERANCE = 10.0 # this works fine for linux
# IMGCOMP_TOLERANCE = 11.0 # required for a windows pass. (really 10.25 may do it).

_df = pd.DataFrame()
def get_ema_data():
global _df
if len(_df) == 0:
_df = pd.read_csv('./examples/data/yahoofinance-GOOG-20040819-20180120.csv',
index_col='Date',parse_dates=True)
return _df


def create_ema_image(tname):

df = get_ema_data()
df = df[-50:] # show last 50 data points only

ema25 = df['Close'].ewm(span=25.0, adjust=False).mean()
mav25 = df['Close'].rolling(window=25).mean()

ap = [
mpf.make_addplot(df, panel=1, type='ohlc', color='c',
ylabel='mpf mav', mav=25, secondary_y=False),
mpf.make_addplot(ema25, panel=2, type='line', width=2, color='c',
ylabel='calculated', secondary_y=False),
mpf.make_addplot(mav25, panel=2, type='line', width=2, color='blue',
ylabel='calculated', secondary_y=False)
]

# plot and save in `tname` path
mpf.plot(df, ylabel="mpf ema", type='ohlc',
ema=25, addplot=ap, panel_ratios=(1, 1), savefig=tname
)


def test_ema01():

fname = base+'01.png'
tname = os.path.join(tdir,fname)
rname = os.path.join(refd,fname)

create_ema_image(tname)

tsize = os.path.getsize(tname)
print(glob.glob(tname),'[',tsize,'bytes',']')

rsize = os.path.getsize(rname)
print(glob.glob(rname),'[',rsize,'bytes',']')

result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
if result is not None:
print('result=',result)
assert result is None

def test_ema02():
fname = base+'02.png'
tname = os.path.join(tdir,fname)
rname = os.path.join(refd,fname)

df = get_ema_data()
df = df[-125:-35]

mpf.plot(df, type='candle', ema=(5,15,25), mav=(5,15,25), savefig=tname)

tsize = os.path.getsize(tname)
print(glob.glob(tname),'[',tsize,'bytes',']')

rsize = os.path.getsize(rname)
print(glob.glob(rname),'[',rsize,'bytes',']')

result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
if result is not None:
print('result=',result)
assert result is None

def test_ema03():
fname = base+'03.png'
tname = os.path.join(tdir,fname)
rname = os.path.join(refd,fname)

df = get_ema_data()
df = df[-125:-35]

mac = ['red','orange','yellow','green','blue','purple']

mpf.plot(df, type='candle', ema=(5,10,15,25), mav=(5,15,25),
mavcolors=mac, savefig=tname)


tsize = os.path.getsize(tname)
print(glob.glob(tname),'[',tsize,'bytes',']')

rsize = os.path.getsize(rname)
print(glob.glob(rname),'[',rsize,'bytes',']')

result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
if result is not None:
print('result=',result)
assert result is None