Skip to content

Commit 4d7d5d6

Browse files
committed
add tests for replace and generate sidecar
Signed-off-by: Kevin <[email protected]>
1 parent f513e3c commit 4d7d5d6

File tree

6 files changed

+97
-26
lines changed

6 files changed

+97
-26
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from ..utils.openshift_oauth import (
3434
create_openshift_oauth_objects,
3535
delete_openshift_oauth_objects,
36-
download_tls_cert,
3736
)
3837
from .config import ClusterConfiguration
3938
from .model import (

src/codeflare_sdk/job/jobs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
if TYPE_CHECKING:
3030
from ..cluster.cluster import Cluster
3131
from ..cluster.cluster import get_current_namespace
32-
from ..utils.openshift_oauth import download_tls_cert
3332

3433
all_jobs: List["Job"] = []
3534

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
from kubernetes import client, config
3232

33+
from .kube_api_helpers import _get_api_host
34+
3335

3436
def read_template(template):
3537
with open(template, "r") as stream:
@@ -387,7 +389,7 @@ def enable_openshift_oauth(user_yaml, cluster_name, namespace):
387389
tls_secret_name = f"{cluster_name}-proxy-tls-secret"
388390
tls_volume_name = "proxy-tls-secret"
389391
port_name = "oauth-proxy"
390-
_, _, host, _, _, _, _ = parse_url(k8_client.configuration.host)
392+
host = _get_api_host(k8_client)
391393
host = host.replace(
392394
"api.", f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps."
393395
)
@@ -414,14 +416,14 @@ def enable_openshift_oauth(user_yaml, cluster_name, namespace):
414416
user_yaml["spec"]["resources"]["GenericItems"].pop(1)
415417
ray_headgroup_pod["serviceAccount"] = oauth_sa_name
416418
ray_headgroup_pod["volumes"] = ray_headgroup_pod.get("volumes", [])
419+
420+
# we use a generic api client here so that the serialization function doesn't need to be mocked for unit tests
417421
ray_headgroup_pod["volumes"].append(
418-
k8_client.sanitize_for_serialization(tls_secret_volume)
422+
client.ApiClient().sanitize_for_serialization(tls_secret_volume)
419423
)
420424
ray_headgroup_pod["containers"].append(
421-
k8_client.sanitize_for_serialization(oauth_sidecar)
425+
client.ApiClient().sanitize_for_serialization(oauth_sidecar)
422426
)
423-
# add volume to headnode
424-
# add sidecar container to ray object
425427

426428

427429
def _create_oauth_sidecar_object(

src/codeflare_sdk/utils/kube_api_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import executing
2121
from kubernetes import client, config
22+
from urllib3.util import parse_url
2223

2324

2425
# private methods
@@ -42,3 +43,7 @@ def _kube_api_error_handling(e: Exception): # pragma: no cover
4243
elif e.reason == "Conflict":
4344
raise FileExistsError(exists_msg)
4445
raise e
46+
47+
48+
def _get_api_host(api_client: client.ApiClient): # pragma: no cover
49+
return parse_url(api_client.configuration.host).host

src/codeflare_sdk/utils/openshift_oauth.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
from urllib3.util import parse_url
22
from .generate_yaml import gen_dashboard_route_name
3+
from .kube_api_helpers import _get_api_host
34
from base64 import b64decode
45

56
from ..cluster.auth import config_check, api_config_handler
67

78
from kubernetes import client
89

910

10-
def _get_api_host(api_client: client.ApiClient):
11-
return parse_url(api_client.configuration.host).host
12-
13-
1411
def create_openshift_oauth_objects(cluster_name, namespace):
1512
config_check()
1613
api_client = api_config_handler()
@@ -118,19 +115,6 @@ def delete_openshift_oauth_objects(cluster_name, namespace):
118115
)
119116

120117

121-
def download_tls_cert(cluster_name, namespace, output_file):
122-
api_client = api_config_handler()
123-
b64_tls_cert = (
124-
client.CoreV1Api(api_client)
125-
.read_namespaced_secret(
126-
name=_gen_tls_secret_name(cluster_name=cluster_name), namespace=namespace
127-
)
128-
.data["tls.crt"]
129-
)
130-
with open(output_file, "w+") as f:
131-
f.write(b64decode(b64_tls_cert).decode("ascii"))
132-
133-
134118
def _create_or_replace_oauth_service_obj(
135119
cluster_name: str,
136120
namespace: str,

tests/unit_test.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
Authentication,
4141
KubeConfigFileAuthentication,
4242
config_check,
43-
api_config_handler,
4443
)
4544
from codeflare_sdk.utils.openshift_oauth import create_openshift_oauth_objects
4645
from codeflare_sdk.utils.pretty_print import (
@@ -76,6 +75,8 @@
7675
createDDPJob_with_cluster,
7776
)
7877

78+
import codeflare_sdk.utils.kube_api_helpers
79+
7980
import openshift
8081
from openshift.selector import Selector
8182
import ray
@@ -2295,7 +2296,6 @@ def test_export_env():
22952296
)
22962297

22972298

2298-
# TODO add checks to make sure that when calling generate oauth objects that the fields that are expected to match do
22992299
def test_create_openshift_oauth(mocker: MockerFixture):
23002300
create_namespaced_service_account = MagicMock()
23012301
create_cluster_role_binding = MagicMock()
@@ -2345,6 +2345,88 @@ def test_create_openshift_oauth(mocker: MockerFixture):
23452345
)
23462346

23472347

2348+
def test_replace_openshift_oauth(mocker: MockerFixture):
2349+
# not_found_exception = client.ApiException(reason="Conflict")
2350+
create_namespaced_service_account = MagicMock(
2351+
side_effect=client.ApiException(reason="Conflict")
2352+
)
2353+
create_cluster_role_binding = MagicMock(
2354+
side_effect=client.ApiException(reason="Conflict")
2355+
)
2356+
create_namespaced_service = MagicMock(
2357+
side_effect=client.ApiException(reason="Conflict")
2358+
)
2359+
create_namespaced_ingress = MagicMock(
2360+
side_effect=client.ApiException(reason="Conflict")
2361+
)
2362+
mocker.patch.object(
2363+
client.CoreV1Api,
2364+
"create_namespaced_service_account",
2365+
create_namespaced_service_account,
2366+
)
2367+
mocker.patch.object(
2368+
client.RbacAuthorizationV1Api,
2369+
"create_cluster_role_binding",
2370+
create_cluster_role_binding,
2371+
)
2372+
mocker.patch.object(
2373+
client.CoreV1Api, "create_namespaced_service", create_namespaced_service
2374+
)
2375+
mocker.patch.object(
2376+
client.NetworkingV1Api, "create_namespaced_ingress", create_namespaced_ingress
2377+
)
2378+
mocker.patch(
2379+
"codeflare_sdk.utils.openshift_oauth._get_api_host", return_value="foo.com"
2380+
)
2381+
replace_namespaced_service_account = MagicMock()
2382+
replace_cluster_role_binding = MagicMock()
2383+
replace_namespaced_service = MagicMock()
2384+
replace_namespaced_ingress = MagicMock()
2385+
mocker.patch.object(
2386+
client.CoreV1Api,
2387+
"replace_namespaced_service_account",
2388+
replace_namespaced_service_account,
2389+
)
2390+
mocker.patch.object(
2391+
client.RbacAuthorizationV1Api,
2392+
"replace_cluster_role_binding",
2393+
replace_cluster_role_binding,
2394+
)
2395+
mocker.patch.object(
2396+
client.CoreV1Api, "replace_namespaced_service", replace_namespaced_service
2397+
)
2398+
mocker.patch.object(
2399+
client.NetworkingV1Api, "replace_namespaced_ingress", replace_namespaced_ingress
2400+
)
2401+
create_openshift_oauth_objects("foo", "bar")
2402+
replace_namespaced_service_account.assert_called_once()
2403+
replace_cluster_role_binding.assert_called_once()
2404+
replace_namespaced_service.assert_called_once()
2405+
replace_namespaced_ingress.assert_called_once()
2406+
2407+
2408+
def test_gen_app_wrapper_with_oauth(mocker: MockerFixture):
2409+
mocker.patch(
2410+
"codeflare_sdk.utils.generate_yaml._get_api_host", return_value="foo.com"
2411+
)
2412+
mocker.patch(
2413+
"codeflare_sdk.cluster.cluster.get_current_namespace",
2414+
return_value="opendatahub",
2415+
)
2416+
write_user_appwrapper = MagicMock()
2417+
mocker.patch(
2418+
"codeflare_sdk.utils.generate_yaml.write_user_appwrapper", write_user_appwrapper
2419+
)
2420+
Cluster(ClusterConfiguration("test_cluster", openshift_oauth=True))
2421+
user_yaml = write_user_appwrapper.call_args.args[0]
2422+
assert any(
2423+
container["name"] == "oauth-proxy"
2424+
for container in user_yaml["spec"]["resources"]["GenericItems"][0][
2425+
"generictemplate"
2426+
]["spec"]["headGroupSpec"]["template"]["spec"]["containers"]
2427+
)
2428+
2429+
23482430
# Make sure to always keep this function last
23492431
def test_cleanup():
23502432
os.remove("unit-test-cluster.yaml")

0 commit comments

Comments
 (0)