|
2 | 2 |
|
3 | 3 | from time import sleep
|
4 | 4 |
|
5 |
| -from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication |
| 5 | +from codeflare_sdk import Cluster, ClusterConfiguration, get_cluster |
6 | 6 | from codeflare_sdk.ray.client import RayJobClient
|
7 | 7 |
|
8 | 8 | import pytest
|
@@ -68,6 +68,10 @@ def run_mnist_raycluster_sdk_kind(
|
68 | 68 |
|
69 | 69 | self.assert_jobsubmit_withoutlogin_kind(cluster, accelerator, number_of_gpus)
|
70 | 70 |
|
| 71 | + self.assert_get_cluster_and_jobsubmit( |
| 72 | + "mnist", self.namespace, accelerator, number_of_gpus |
| 73 | + ) |
| 74 | + |
71 | 75 | # Assertions
|
72 | 76 |
|
73 | 77 | def assert_jobsubmit_withoutlogin_kind(self, cluster, accelerator, number_of_gpus):
|
@@ -105,12 +109,47 @@ def assert_jobsubmit_withoutlogin_kind(self, cluster, accelerator, number_of_gpu
|
105 | 109 |
|
106 | 110 | client.delete_job(submission_id)
|
107 | 111 |
|
108 |
| - cluster.down() |
109 |
| - |
110 | 112 | def assert_job_completion(self, status):
|
111 | 113 | if status == "SUCCEEDED":
|
112 | 114 | print(f"Job has completed: '{status}'")
|
113 | 115 | assert True
|
114 | 116 | else:
|
115 | 117 | print(f"Job has completed: '{status}'")
|
116 | 118 | assert False
|
| 119 | + |
| 120 | + def assert_get_cluster_and_jobsubmit( |
| 121 | + self, cluster_name, namespace, accelerator, number_of_gpus |
| 122 | + ): |
| 123 | + # Retrieve the cluster |
| 124 | + cluster = get_cluster(cluster_name, namespace) |
| 125 | + |
| 126 | + cluster.details() |
| 127 | + |
| 128 | + cluster.config.verify_tls = False |
| 129 | + |
| 130 | + # Initialize the job client |
| 131 | + client = cluster.job_client |
| 132 | + |
| 133 | + # Submit a job and get the submission ID |
| 134 | + submission_id = client.submit_job( |
| 135 | + entrypoint="python mnist.py", |
| 136 | + runtime_env={ |
| 137 | + "working_dir": "./tests/e2e/", |
| 138 | + "pip": "./tests/e2e/mnist_pip_requirements.txt", |
| 139 | + "env_vars": get_setup_env_variables(ACCELERATOR=accelerator), |
| 140 | + }, |
| 141 | + entrypoint_num_gpus=number_of_gpus, |
| 142 | + ) |
| 143 | + print(f"Submitted job with ID: {submission_id}") |
| 144 | + |
| 145 | + # Fetch the list of jobs and validate |
| 146 | + job_list = client.list_jobs() |
| 147 | + print(f"List of Jobs: {job_list}") |
| 148 | + |
| 149 | + # Validate the number of jobs in the list |
| 150 | + assert len(job_list) == 1 |
| 151 | + |
| 152 | + # Validate the submission ID matches |
| 153 | + assert job_list[0].submission_id == submission_id |
| 154 | + |
| 155 | + cluster.down() |
0 commit comments