Skip to content

Commit a12bb00

Browse files
committed
Add test coverage to validate the functionality of the get_cluster method
1 parent 003a287 commit a12bb00

File tree

2 files changed

+80
-6
lines changed

2 files changed

+80
-6
lines changed

tests/e2e/mnist_raycluster_sdk_kind_test.py

+38-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from time import sleep
44

5-
from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication
5+
from codeflare_sdk import Cluster, ClusterConfiguration, get_cluster
66
from codeflare_sdk.ray.client import RayJobClient
77

88
import pytest
@@ -68,6 +68,8 @@ def run_mnist_raycluster_sdk_kind(
6868

6969
self.assert_jobsubmit_withoutlogin_kind(cluster, accelerator, number_of_gpus)
7070

71+
self.assert_get_cluster_and_jobsubmit("mnist", self.namespace)
72+
7173
# Assertions
7274

7375
def assert_jobsubmit_withoutlogin_kind(self, cluster, accelerator, number_of_gpus):
@@ -105,12 +107,45 @@ def assert_jobsubmit_withoutlogin_kind(self, cluster, accelerator, number_of_gpu
105107

106108
client.delete_job(submission_id)
107109

108-
cluster.down()
109-
110110
def assert_job_completion(self, status):
111111
if status == "SUCCEEDED":
112112
print(f"Job has completed: '{status}'")
113113
assert True
114114
else:
115115
print(f"Job has completed: '{status}'")
116116
assert False
117+
118+
def assert_get_cluster_and_jobsubmit(self, cluster_name, namespace):
119+
# Retrieve the cluster
120+
cluster = get_cluster(cluster_name, namespace)
121+
122+
cluster.details()
123+
124+
cluster.config.verify_tls = False
125+
126+
# Initialize the job client
127+
client = cluster.job_client
128+
129+
# Submit a job and get the submission ID
130+
submission_id = client.submit_job(
131+
entrypoint="python mnist.py",
132+
runtime_env={
133+
"working_dir": "./tests/e2e/",
134+
"pip": "./tests/e2e/mnist_pip_requirements.txt",
135+
"env_vars": get_setup_env_variables(ACCELERATOR=accelerator),
136+
},
137+
entrypoint_num_gpus=number_of_gpus,
138+
)
139+
print(f"Submitted job with ID: {submission_id}")
140+
141+
# Fetch the list of jobs and validate
142+
job_list = client.list_jobs()
143+
print(f"List of Jobs: {job_list}")
144+
145+
# Validate the number of jobs in the list
146+
assert len(job_list) == 1
147+
148+
# Validate the submission ID matches
149+
assert job_list[0].submission_id == submission_id
150+
151+
cluster.down()

tests/e2e/mnist_raycluster_sdk_oauth_test.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
from time import sleep
44

5-
from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication
5+
from codeflare_sdk import (
6+
Cluster,
7+
ClusterConfiguration,
8+
TokenAuthentication,
9+
get_cluster,
10+
)
611
from codeflare_sdk.ray.client import RayJobClient
712

813
import pytest
@@ -68,6 +73,7 @@ def run_mnist_raycluster_sdk_oauth(self):
6873

6974
self.assert_jobsubmit_withoutLogin(cluster)
7075
self.assert_jobsubmit_withlogin(cluster)
76+
self.assert_get_cluster_and_jobsubmit("mnist", self.namespace)
7177

7278
# Assertions
7379

@@ -132,12 +138,45 @@ def assert_jobsubmit_withlogin(self, cluster):
132138

133139
client.delete_job(submission_id)
134140

135-
cluster.down()
136-
137141
def assert_job_completion(self, status):
138142
if status == "SUCCEEDED":
139143
print(f"Job has completed: '{status}'")
140144
assert True
141145
else:
142146
print(f"Job has completed: '{status}'")
143147
assert False
148+
149+
def assert_get_cluster_and_jobsubmit(self, cluster_name, namespace):
150+
# Retrieve the cluster
151+
cluster = get_cluster(cluster_name, namespace)
152+
153+
cluster.details()
154+
155+
cluster.config.verify_tls = False
156+
157+
# Initialize the job client
158+
client = cluster.job_client
159+
160+
# Submit a job and get the submission ID
161+
submission_id = client.submit_job(
162+
entrypoint="python mnist.py",
163+
runtime_env={
164+
"working_dir": "./tests/e2e/",
165+
"pip": "./tests/e2e/mnist_pip_requirements.txt",
166+
"env_vars": get_setup_env_variables(),
167+
},
168+
entrypoint_num_cpus=1,
169+
)
170+
print(f"Submitted job with ID: {submission_id}")
171+
172+
# Fetch the list of jobs and validate
173+
job_list = client.list_jobs()
174+
print(f"List of Jobs: {job_list}")
175+
176+
# Validate the number of jobs in the list
177+
assert len(job_list) == 1
178+
179+
# Validate the submission ID matches
180+
assert job_list[0].submission_id == submission_id
181+
182+
cluster.down()

0 commit comments

Comments
 (0)