Skip to content

Commit ba65217

Browse files
committed
Added legend support for gantt charts and added more tests
1 parent cfa066e commit ba65217

File tree

2 files changed

+110
-67
lines changed

2 files changed

+110
-67
lines changed

plotly/tests/test_core/test_tools/test_figure_factory.py

+28-16
Original file line numberDiff line numberDiff line change
@@ -1204,9 +1204,9 @@ def test_gantt_validate_colors(self):
12041204
# validate the gantt colors variable
12051205

12061206
df = [dict(Task='Job A', Start='2009-02-01',
1207-
Finish='2009-08-30', Complete=75),
1207+
Finish='2009-08-30', Complete=75, Resource='A'),
12081208
dict(Task='Job B', Start='2009-02-01',
1209-
Finish='2009-08-30', Complete=50)]
1209+
Finish='2009-08-30', Complete=50, Resource='B')]
12101210

12111211
pattern = ("Whoops! The elements in your rgb colors tuples cannot "
12121212
"exceed 255.0.")
@@ -1225,41 +1225,53 @@ def test_gantt_validate_colors(self):
12251225
tls.FigureFactory.create_gantt, df,
12261226
index_col='Complete', colors=(2, 1, 1))
12271227

1228-
pattern3 = ("You must input a valid colors. Valid types include a "
1229-
"plotly scale, rgb, hex or tuple color, a list of any "
1230-
"color types, or a dictionary with index names each "
1231-
"assigned to a color.")
1232-
1233-
self.assertRaisesRegexp(PlotlyError, pattern3,
1234-
tls.FigureFactory.create_gantt, df,
1235-
index_col='Complete', colors=5)
1236-
12371228
# verify that if colors is a dictionary, its keys span all the
12381229
# values in the index column
12391230
colors_dict = {75: 'rgb(1, 2, 3)'}
12401231

1241-
pattern4 = ("If you are using colors as a dictionary, all of its "
1232+
pattern3 = ("If you are using colors as a dictionary, all of its "
12421233
"keys must be all the values in the index column.")
12431234

1244-
self.assertRaisesRegexp(PlotlyError, pattern4,
1235+
self.assertRaisesRegexp(PlotlyError, pattern3,
12451236
tls.FigureFactory.create_gantt, df,
12461237
index_col='Complete', colors=colors_dict)
12471238

12481239
# check: index is set if colors is a dictionary
12491240
colors_dict_good = {50: 'rgb(1, 2, 3)', 75: 'rgb(5, 10, 15)'}
12501241

1251-
pattern5 = ("Error. You have set colors to a dictionary but have not "
1242+
pattern4 = ("Error. You have set colors to a dictionary but have not "
12521243
"picked an index. An index is required if you are "
12531244
"assigning colors to particular values in a dictioanry.")
12541245

1255-
self.assertRaisesRegexp(PlotlyError, pattern5,
1246+
self.assertRaisesRegexp(PlotlyError, pattern4,
12561247
tls.FigureFactory.create_gantt, df,
12571248
colors=colors_dict_good)
12581249

1250+
# check: number of colors is equal to or greater than number of
1251+
# unique index string values
1252+
pattern5 = ("Error. The number of colors in 'colors' must be no less "
1253+
"than the number of unique index values in your group "
1254+
"column.")
1255+
1256+
self.assertRaisesRegexp(PlotlyError, pattern5,
1257+
tls.FigureFactory.create_gantt, df,
1258+
index_col='Resource',
1259+
colors=['#ffffff'])
1260+
1261+
# check: if index is numeric, colors has at least 2 colors in it
1262+
pattern6 = ("You must use at least 2 colors in 'colors' if you "
1263+
"are using a colorscale. However only the first two "
1264+
"colors given will be used for the lower and upper "
1265+
"bounds on the colormap.")
1266+
1267+
self.assertRaisesRegexp(PlotlyError, pattern6,
1268+
tls.FigureFactory.create_gantt, df,
1269+
index_col='Complete',
1270+
colors=['#ffffff'])
1271+
12591272
def test_gantt_all_args(self):
12601273

12611274
# check if gantt chart matches with expected output
1262-
12631275
df = [dict(Task="Run",
12641276
Start='2010-01-01',
12651277
Finish='2011-02-02',

plotly/tools.py

+82-51
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
DEFAULT_HISTNORM = 'probability density'
5454
ALTERNATIVE_HISTNORM = 'probability'
5555

56+
5657
# Warning format
5758
def warning_on_one_line(message, category, filename, lineno,
5859
file=None, line=None):
@@ -1484,12 +1485,10 @@ def _validate_gantt(df):
14841485
# validate that df has all the required keys
14851486
for key in REQUIRED_GANTT_KEYS:
14861487
if key not in df:
1487-
raise exceptions.PlotlyError("The columns in your data"
1488-
"frame must include the "
1489-
"keys".format(
1490-
REQUIRED_GANTT_KEYS
1491-
)
1492-
)
1488+
raise exceptions.PlotlyError(
1489+
"The columns in your dataframe must include the "
1490+
"keys".format(REQUIRED_GANTT_KEYS)
1491+
)
14931492

14941493
num_of_rows = len(df.index)
14951494
chart = []
@@ -1634,8 +1633,7 @@ def _gantt_colorscale(chart, colors, title, index_col, show_colorbar,
16341633
task_names = []
16351634
if data is None:
16361635
data = []
1637-
1638-
#if chart[index_col]
1636+
showlegend = False
16391637

16401638
for index in range(len(chart)):
16411639
task = dict(x0=chart[index]['Start'],
@@ -1656,6 +1654,14 @@ def _gantt_colorscale(chart, colors, title, index_col, show_colorbar,
16561654

16571655
# compute the color for task based on indexing column
16581656
if isinstance(chart[0][index_col], Number):
1657+
# check that colors has at least 2 colors
1658+
if len(colors) < 2:
1659+
raise exceptions.PlotlyError(
1660+
"You must use at least 2 colors in 'colors' if you "
1661+
"are using a colorscale. However only the first two "
1662+
"colors given will be used for the lower and upper "
1663+
"bounds on the colormap."
1664+
)
16591665
for index in range(len(tasks)):
16601666
tn = tasks[index]['name']
16611667
task_names.append(tn)
@@ -1693,6 +1699,22 @@ def _gantt_colorscale(chart, colors, title, index_col, show_colorbar,
16931699
marker={'color': 'white'}
16941700
)
16951701
)
1702+
1703+
if show_colorbar is True:
1704+
# generate dummy data for colorscale visibility
1705+
data.append(
1706+
dict(
1707+
x=[tasks[index]['x0'], tasks[index]['x0']],
1708+
y=[index, index],
1709+
name='',
1710+
marker={'color': 'white',
1711+
'colorscale': [[0, colors[0]], [1, colors[1]]],
1712+
'showscale': True,
1713+
'cmax': 100,
1714+
'cmin': 0}
1715+
)
1716+
)
1717+
16961718
if isinstance(chart[0][index_col], str):
16971719
index_vals = []
16981720
for row in range(len(tasks)):
@@ -1701,6 +1723,13 @@ def _gantt_colorscale(chart, colors, title, index_col, show_colorbar,
17011723

17021724
index_vals.sort()
17031725

1726+
if len(colors) < len(index_vals):
1727+
raise exceptions.PlotlyError(
1728+
"Error. The number of colors in 'colors' must be no less "
1729+
"than the number of unique index values in your group "
1730+
"column."
1731+
)
1732+
17041733
# make a dictionary assignment to each index value
17051734
index_vals_dict = {}
17061735
# define color index
@@ -1733,24 +1762,27 @@ def _gantt_colorscale(chart, colors, title, index_col, show_colorbar,
17331762
)
17341763
)
17351764

1736-
if show_colorbar is True:
1737-
# generate dummy data for colorscale visibility
1738-
data.append(
1739-
dict(
1740-
x=[tasks[index]['x0'], tasks[index]['x0']],
1741-
y=[index, index],
1742-
name='',
1743-
marker={'color': 'white',
1744-
'colorscale': [[0, colors[0]], [1, colors[1]]],
1745-
'showscale': True,
1746-
'cmax': 100,
1747-
'cmin': 0}
1748-
)
1749-
)
1765+
if show_colorbar is True:
1766+
# generate dummy data to generate legend
1767+
showlegend = True
1768+
for k, index_value in enumerate(index_vals):
1769+
data.append(
1770+
dict(
1771+
x=[tasks[index]['x0'], tasks[index]['x0']],
1772+
y=[k, k],
1773+
showlegend=True,
1774+
name=str(index_value),
1775+
hoverinfo='none',
1776+
marker=dict(
1777+
color=colors[k],
1778+
size=1
1779+
)
1780+
)
1781+
)
17501782

17511783
layout = dict(
17521784
title=title,
1753-
showlegend=False,
1785+
showlegend=showlegend,
17541786
height=height,
17551787
width=width,
17561788
shapes=[],
@@ -1812,6 +1844,7 @@ def _gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width,
18121844
task_names = []
18131845
if data is None:
18141846
data = []
1847+
showlegend = False
18151848

18161849
for index in range(len(chart)):
18171850
task = dict(x0=chart[index]['Start'],
@@ -1865,24 +1898,27 @@ def _gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width,
18651898
)
18661899
)
18671900

1868-
#if show_colorbar is True:
1869-
# generate dummy data for colorscale visibility
1870-
# trace2 = dict(
1871-
# #x=[tasks[0]['x0'], tasks[0]['x0']],
1872-
# x=[2, 6],
1873-
# y=[4, 2],
1874-
# name='asdf',
1875-
# visible='legendonly',
1876-
# marker=dict(
1877-
# size=10,
1878-
# color='rgb(25, 50, 150)'),
1879-
# showlegend=True
1880-
# )
1881-
# data.append(trace2)
1901+
if show_colorbar is True:
1902+
# generate dummy data to generate legend
1903+
showlegend = True
1904+
for k, index_value in enumerate(index_vals):
1905+
data.append(
1906+
dict(
1907+
x=[tasks[index]['x0'], tasks[index]['x0']],
1908+
y=[k, k],
1909+
showlegend=True,
1910+
hoverinfo='none',
1911+
name=str(index_value),
1912+
marker=dict(
1913+
color=colors[index_value],
1914+
size=1
1915+
)
1916+
)
1917+
)
18821918

18831919
layout = dict(
18841920
title=title,
1885-
showlegend=False,
1921+
showlegend=showlegend,
18861922
height=height,
18871923
width=width,
18881924
shapes=[],
@@ -1946,9 +1982,9 @@ def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
19461982
used for indexing. If a list, its elements must be dictionaries
19471983
with the same required column headers: 'Task', 'Start' and
19481984
'Finish'.
1949-
:param (str|list|dict) colors: either a plotly scale name, an rgb
1950-
or hex color, a color tuple or a list of colors. An rgb color is
1951-
of the form 'rgb(x, y, z)' where x, y, z belong to the interval
1985+
:param (str|list|dict|tuple) colors: either a plotly scale name, an
1986+
rgb or hex color, a color tuple or a list of colors. An rgb color
1987+
is of the form 'rgb(x, y, z)' where x, y, z belong to the interval
19521988
[0, 255] and a color tuple is a tuple of the form (a, b, c) where
19531989
a, b and c belong to [0, 1]. If colors is a list, it must
19541990
contain the valid color types aforementioned as its members.
@@ -2024,7 +2060,8 @@ def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
20242060
(1, 0, 1),
20252061
'#6c4774'],
20262062
index_col='Resource',
2027-
reverse_colors=True)
2063+
reverse_colors=True,
2064+
show_colorbar=True)
20282065
20292066
# Plot the data
20302067
py.iplot(fig, filename='String Entries', world_readable=True)
@@ -2049,7 +2086,9 @@ def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
20492086
'Banana': (1, 1, 0.2)}
20502087
20512088
# Create a figure with Plotly colorscale
2052-
fig = FF.create_gantt(df, colors=colors, index_col='Resource')
2089+
fig = FF.create_gantt(df, colors=colors,
2090+
index_col='Resource',
2091+
show_colorbar=True)
20532092
20542093
# Plot the data
20552094
py.iplot(fig, filename='dictioanry colors', world_readable=True)
@@ -2095,14 +2134,6 @@ def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
20952134
return fig
20962135
else:
20972136
if not isinstance(colors, dict):
2098-
# check that colors has at least 2 colors
2099-
if len(colors) < 2:
2100-
raise exceptions.PlotlyError(
2101-
"You must use at least 2 colors in 'colors' if you "
2102-
"are using a colorscale. However only the first two "
2103-
"colors given will be used for the lower and upper "
2104-
"bounds on the colormap."
2105-
)
21062137
fig = FigureFactory._gantt_colorscale(
21072138
chart, colors, title, index_col, show_colorbar, bar_width,
21082139
showgrid_x, showgrid_y, height, width,

0 commit comments

Comments
 (0)