Skip to content

Commit c955fe0

Browse files
committed
Add test coverage to validate the functionality of the get_cluster method
1 parent 003a287 commit c955fe0

File tree

2 files changed

+84
-6
lines changed

2 files changed

+84
-6
lines changed

tests/e2e/mnist_raycluster_sdk_kind_test.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from time import sleep
44

5-
from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication
5+
from codeflare_sdk import Cluster, ClusterConfiguration, get_cluster
66
from codeflare_sdk.ray.client import RayJobClient
77

88
import pytest
@@ -68,6 +68,10 @@ def run_mnist_raycluster_sdk_kind(
6868

6969
self.assert_jobsubmit_withoutlogin_kind(cluster, accelerator, number_of_gpus)
7070

71+
self.assert_get_cluster_and_jobsubmit(
72+
"mnist", self.namespace, accelerator, number_of_gpus
73+
)
74+
7175
# Assertions
7276

7377
def assert_jobsubmit_withoutlogin_kind(self, cluster, accelerator, number_of_gpus):
@@ -105,12 +109,47 @@ def assert_jobsubmit_withoutlogin_kind(self, cluster, accelerator, number_of_gpu
105109

106110
client.delete_job(submission_id)
107111

108-
cluster.down()
109-
110112
def assert_job_completion(self, status):
111113
if status == "SUCCEEDED":
112114
print(f"Job has completed: '{status}'")
113115
assert True
114116
else:
115117
print(f"Job has completed: '{status}'")
116118
assert False
119+
120+
def assert_get_cluster_and_jobsubmit(
121+
self, cluster_name, namespace, accelerator, number_of_gpus
122+
):
123+
# Retrieve the cluster
124+
cluster = get_cluster(cluster_name, namespace)
125+
126+
cluster.details()
127+
128+
cluster.config.verify_tls = False
129+
130+
# Initialize the job client
131+
client = cluster.job_client
132+
133+
# Submit a job and get the submission ID
134+
submission_id = client.submit_job(
135+
entrypoint="python mnist.py",
136+
runtime_env={
137+
"working_dir": "./tests/e2e/",
138+
"pip": "./tests/e2e/mnist_pip_requirements.txt",
139+
"env_vars": get_setup_env_variables(ACCELERATOR=accelerator),
140+
},
141+
entrypoint_num_gpus=number_of_gpus,
142+
)
143+
print(f"Submitted job with ID: {submission_id}")
144+
145+
# Fetch the list of jobs and validate
146+
job_list = client.list_jobs()
147+
print(f"List of Jobs: {job_list}")
148+
149+
# Validate the number of jobs in the list
150+
assert len(job_list) == 1
151+
152+
# Validate the submission ID matches
153+
assert job_list[0].submission_id == submission_id
154+
155+
cluster.down()

tests/e2e/mnist_raycluster_sdk_oauth_test.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
from time import sleep
44

5-
from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication
5+
from codeflare_sdk import (
6+
Cluster,
7+
ClusterConfiguration,
8+
TokenAuthentication,
9+
get_cluster,
10+
)
611
from codeflare_sdk.ray.client import RayJobClient
712

813
import pytest
@@ -68,6 +73,7 @@ def run_mnist_raycluster_sdk_oauth(self):
6873

6974
self.assert_jobsubmit_withoutLogin(cluster)
7075
self.assert_jobsubmit_withlogin(cluster)
76+
self.assert_get_cluster_and_jobsubmit("mnist", self.namespace)
7177

7278
# Assertions
7379

@@ -132,12 +138,45 @@ def assert_jobsubmit_withlogin(self, cluster):
132138

133139
client.delete_job(submission_id)
134140

135-
cluster.down()
136-
137141
def assert_job_completion(self, status):
138142
if status == "SUCCEEDED":
139143
print(f"Job has completed: '{status}'")
140144
assert True
141145
else:
142146
print(f"Job has completed: '{status}'")
143147
assert False
148+
149+
def assert_get_cluster_and_jobsubmit(self, cluster_name, namespace):
150+
# Retrieve the cluster
151+
cluster = get_cluster(cluster_name, namespace)
152+
153+
cluster.details()
154+
155+
cluster.config.verify_tls = False
156+
157+
# Initialize the job client
158+
client = cluster.job_client
159+
160+
# Submit a job and get the submission ID
161+
submission_id = client.submit_job(
162+
entrypoint="python mnist.py",
163+
runtime_env={
164+
"working_dir": "./tests/e2e/",
165+
"pip": "./tests/e2e/mnist_pip_requirements.txt",
166+
"env_vars": get_setup_env_variables(),
167+
},
168+
entrypoint_num_cpus=1,
169+
)
170+
print(f"Submitted job with ID: {submission_id}")
171+
172+
# Fetch the list of jobs and validate
173+
job_list = client.list_jobs()
174+
print(f"List of Jobs: {job_list}")
175+
176+
# Validate the number of jobs in the list
177+
assert len(job_list) == 1
178+
179+
# Validate the submission ID matches
180+
assert job_list[0].submission_id == submission_id
181+
182+
cluster.down()

0 commit comments

Comments
 (0)