Skip to content

Commit 5560011

Browse files
Setting time series regression default data checks to regression data checks.
1 parent 5ac496d commit 5560011

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

evalml/data_checks/default_data_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self, problem_type):
3232
Arguments:
3333
problem_type (str): The problem type that is being validated. Can be regression, binary, or multiclass.
3434
"""
35-
if handle_problem_types(problem_type) == ProblemTypes.REGRESSION:
35+
if handle_problem_types(problem_type) in [ProblemTypes.REGRESSION, ProblemTypes.TIME_SERIES_REGRESSION]:
3636
super().__init__(self._DEFAULT_DATA_CHECK_CLASSES,
3737
data_check_params={"InvalidTargetDataCheck": {"problem_type": problem_type}})
3838
else:

evalml/tests/automl_tests/test_automl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1828,7 +1828,7 @@ class Pipeline2(TimeSeriesRegressionPipeline):
18281828

18291829
automl = AutoMLSearch(problem_type="time series regression", problem_configuration=configuration,
18301830
allowed_pipelines=[Pipeline1, Pipeline2], max_iterations=4)
1831-
automl.search(X, y, data_checks='disabled')
1831+
automl.search(X, y)
18321832
assert isinstance(automl.data_split, TimeSeriesSplit)
18331833
for result in automl.results['pipeline_results'].values():
18341834
if result["id"] == 0:

evalml/tests/data_checks_tests/test_data_checks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ def test_default_data_checks_regression():
105105
assert data_checks.validate(X, y) == {"warnings": messages[:3], "errors": messages[3:]}
106106

107107

108+
def test_default_data_checks_time_series_regression():
109+
regression_data_check_classes = [check.__class__ for check in DefaultDataChecks("regression").data_checks]
110+
ts_regression_data_check_classes = [check.__class__ for check in DefaultDataChecks("time series regression").data_checks]
111+
assert regression_data_check_classes == ts_regression_data_check_classes
112+
113+
108114
def test_data_checks_init_from_classes():
109115
def make_mock_data_check(check_name):
110116
class MockCheck(DataCheck):

0 commit comments

Comments
 (0)