Skip to content

Commit aa79c14

Browse files
committed
add functions for creating ray with oauth proxy in front of the dashboard
Signed-off-by: Kevin <[email protected]>
1 parent c2013ba commit aa79c14

File tree

4 files changed

+236
-4
lines changed

4 files changed

+236
-4
lines changed

src/codeflare_sdk/cluster/cluster.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,21 @@
1818
cluster setup queue, a list of all existing clusters, and the user's working namespace.
1919
"""
2020

21+
<<<<<<< HEAD
2122
from time import sleep
2223
from typing import List, Optional, Tuple, Dict
2324

25+
import openshift as oc
26+
from kubernetes import config
27+
>>>>>>> bb0a0a7 (add functions for creating ray with oauth proxy in front of the dashboard)
2428
from ray.job_submission import JobSubmissionClient
29+
import urllib3
2530

2631
from .auth import config_check, api_config_handler
2732
from ..utils import pretty_print
2833
from ..utils.generate_yaml import generate_appwrapper
2934
from ..utils.kube_api_helpers import _kube_api_error_handling
35+
from ..utils.openshift_oauth import create_openshift_oauth_objects, delete_openshift_oauth_objects, download_tls_cert
3036
from .config import ClusterConfiguration
3137
from .model import (
3238
AppWrapper,
@@ -41,6 +47,9 @@
4147
import requests
4248

4349

50+
k8_client = config.new_client_from_config()
51+
52+
4453
class Cluster:
4554
"""
4655
An object for requesting, bringing up, and taking down resources.
@@ -61,6 +70,21 @@ def __init__(self, config: ClusterConfiguration):
6170
self.config = config
6271
self.app_wrapper_yaml = self.create_app_wrapper()
6372
self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
73+
self._client = None
74+
75+
@property
def client(self):
    """
    Return the Ray JobSubmissionClient for this cluster, creating it on first use.

    The client is built against the cluster dashboard URI and cached on the
    instance. When ``config.openshift_oauth`` is set, the bearer token from the
    module-level Kubernetes client configuration is sent as the
    ``Authorization`` header.
    """
    if not self._client:
        if not self.config.openshift_oauth:
            self._client = JobSubmissionClient(self.cluster_dashboard_uri())
        else:
            # user must be logged in to OpenShift
            dashboard_uri = self.cluster_dashboard_uri()
            bearer_token = k8_client.configuration.auth_settings()["BearerToken"][
                "value"
            ]
            self._client = JobSubmissionClient(
                dashboard_uri, headers={"Authorization": bearer_token}
            )
    return self._client
6488

6589
def evaluate_dispatch_priority(self):
6690
priority_class = self.config.dispatch_priority
@@ -141,6 +165,7 @@ def create_app_wrapper(self):
141165
image_pull_secrets=image_pull_secrets,
142166
dispatch_priority=dispatch_priority,
143167
priority_val=priority_val,
168+
openshift_oauth=self.config.openshift_oauth,
144169
)
145170

146171
# creates a new cluster with the provided or default spec
@@ -150,6 +175,9 @@ def up(self):
150175
the MCAD queue.
151176
"""
152177
namespace = self.config.namespace
178+
if self.config.openshift_oauth:
179+
create_openshift_oauth_objects(cluster_name=self.config.name, namespace=namespace)
180+
153181
try:
154182
config_check()
155183
api_instance = client.CustomObjectsApi(api_config_handler())
@@ -184,6 +212,9 @@ def down(self):
184212
except Exception as e: # pragma: no cover
185213
return _kube_api_error_handling(e)
186214

215+
if self.config.openshift_oauth:
216+
delete_openshift_oauth_objects(cluster_name=self.config.name, namespace=namespace)
217+
187218
def status(
188219
self, print_to_console: bool = True
189220
) -> Tuple[CodeFlareClusterStatus, bool]:
@@ -322,14 +353,14 @@ def list_jobs(self) -> List:
322353
"""
323354
dashboard_route = self.cluster_dashboard_uri()
324355
client = JobSubmissionClient(dashboard_route)
325-
return client.list_jobs()
356+
return self.client.list_jobs()
326357

327358
def job_status(self, job_id: str) -> str:
    """
    This method accesses the head ray node in your cluster and returns the job status for the provided job id.

    The request goes through the cached ``self.client`` so that the same
    client (including the Authorization header used when OpenShift OAuth is
    enabled) is used here as in ``list_jobs``.
    """
    return self.client.get_job_status(job_id)
334365

335366
def job_logs(self, job_id: str) -> str:
@@ -343,7 +374,7 @@ def job_logs(self, job_id: str) -> str:
343374
def torchx_config(
344375
self, working_dir: str = None, requirements: str = None
345376
) -> Dict[str, str]:
346-
dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}"
377+
dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host
347378
to_return = {
348379
"cluster_name": self.config.name,
349380
"dashboard_address": dashboard_address,

src/codeflare_sdk/cluster/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ class ClusterConfiguration:
4848
local_interactive: bool = False
4949
image_pull_secrets: list = field(default_factory=list)
5050
dispatch_priority: str = None
51+
openshift_oauth: bool = False

src/codeflare_sdk/utils/generate_yaml.py

+76-1
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,20 @@
2121
import sys
2222
import argparse
2323
import uuid
24+
<<<<<<< HEAD
2425
from kubernetes import client, config
2526
from .kube_api_helpers import _kube_api_error_handling
2627
from ..cluster.auth import api_config_handler
28+
=======
29+
from os import urandom
30+
from base64 import b64encode
31+
from urllib3.util import parse_url
32+
>>>>>>> bb0a0a7 (add functions for creating ray with oauth proxy in front of the dashboard)
2733

34+
import openshift as oc
35+
from kubernetes import client, config
36+
37+
k8_client = config.new_client_from_config()
2838

2939
def read_template(template):
3040
with open(template, "r") as stream:
@@ -46,12 +56,14 @@ def gen_names(name):
4656

4757
def update_dashboard_route(route_item, cluster_name, namespace):
    """
    Fill in the dashboard route template for one cluster, mutating it in place.

    Sets the name, namespace and ``odh-ray-cluster-service`` label in the
    route metadata and points the route's ``spec.to`` at the head service.
    """
    head_service = f"{cluster_name}-head-svc"

    route_metadata = route_item.get("generictemplate", {}).get("metadata")
    route_metadata["name"] = gen_dashboard_route_name(cluster_name)
    route_metadata["namespace"] = namespace
    route_metadata["labels"]["odh-ray-cluster-service"] = head_service

    route_spec = route_item.get("generictemplate", {}).get("spec")
    route_spec["to"]["name"] = head_service
5464

65+
def gen_dashboard_route_name(cluster_name):
    """Return the dashboard route name used for the given cluster."""
    return "ray-dashboard-{}".format(cluster_name)
5567

5668
# ToDo: refactor the update_x_route() functions
5769
def update_rayclient_route(route_item, cluster_name, namespace):
@@ -347,6 +359,64 @@ def write_user_appwrapper(user_yaml, output_file_name):
347359
print(f"Written to: {output_file_name}")
348360

349361

362+
def enable_openshift_oauth(user_yaml, cluster_name, namespace):
    """
    Modify an AppWrapper dict in place so the Ray head pod runs an OAuth proxy sidecar.

    Args:
        user_yaml: The AppWrapper document as a nested dict; it is mutated.
        cluster_name: Cluster name used to derive the service account and TLS
            secret names.
        namespace: Namespace passed to the sidecar's delegate-urls argument.

    The function marks the AppWrapper with the ``codeflare-sdk-use-oauth``
    annotation, removes the generic item at index 1, sets the head pod's
    service account, and appends the TLS secret volume and the proxy container
    to the head group pod spec.
    """
    tls_mount_location = "/etc/tls/private"
    oauth_port = 443
    oauth_sa_name = f"{cluster_name}-oauth-proxy"
    tls_secret_name = f"{cluster_name}-proxy-tls-secret"
    tls_volume_name = "proxy-tls-secret"
    port_name = "oauth-proxy"

    oauth_sidecar = _create_oauth_sidecar_object(
        namespace,
        tls_mount_location,
        oauth_port,
        oauth_sa_name,
        tls_volume_name,
        port_name,
    )
    tls_secret_volume = client.V1Volume(
        name=tls_volume_name,
        secret=client.V1SecretVolumeSource(secret_name=tls_secret_name),
    )

    # allows for setting value of Cluster object when initializing object from an existing AppWrapper on cluster
    # `or {}` also covers an `annotations:` key that is present but empty (None).
    annotations = user_yaml["metadata"].get("annotations") or {}
    annotations["codeflare-sdk-use-oauth"] = "true"
    user_yaml["metadata"]["annotations"] = annotations

    generic_items = user_yaml["spec"]["resources"]["GenericItems"]
    ray_headgroup_pod = generic_items[0]["generictemplate"]["spec"]["headGroupSpec"][
        "template"
    ]["spec"]
    # NOTE(review): index 1 is presumably the plain dashboard route item in
    # the template, dropped because the proxy fronts the dashboard - confirm.
    generic_items.pop(1)

    ray_headgroup_pod["serviceAccount"] = oauth_sa_name

    # add volume to headnode
    volumes = ray_headgroup_pod.get("volumes") or []
    volumes.append(k8_client.sanitize_for_serialization(tls_secret_volume))
    ray_headgroup_pod["volumes"] = volumes

    # add sidecar container to ray object
    ray_headgroup_pod["containers"].append(
        k8_client.sanitize_for_serialization(oauth_sidecar)
    )
388+
389+
def _create_oauth_sidecar_object(
    namespace: str,
    tls_mount_location: str,
    oauth_port: int,
    oauth_sa_name: str,
    tls_volume_name: str,
    port_name: str,
) -> client.V1Container:
    """
    Build the ``oauth-proxy`` sidecar container definition.

    Args:
        namespace: Namespace used in the ``--openshift-delegate-urls`` rule.
        tls_mount_location: Directory where the TLS volume is mounted; the
            proxy reads ``tls.crt`` and ``tls.key`` from it.
        oauth_port: Port the proxy serves HTTPS on and exposes as a container port.
        oauth_sa_name: Service account name passed to the proxy.
        tls_volume_name: Name of the volume mounted at ``tls_mount_location``.
        port_name: Name given to the exposed container port.

    Returns:
        A ``V1Container`` whose upstream is ``http://localhost:8265``.
    """
    # Function-scope imports keep this block independent of the module's
    # import section.
    from base64 import b64encode
    from os import urandom

    # A fixed, publicly known cookie secret would let anyone forge session
    # cookies, so create a random string for encrypting the cookie instead.
    cookie_secret = b64encode(urandom(64)).decode("utf-8")

    proxy_args = [
        f"--https-address=:{oauth_port}",
        "--provider=openshift",
        f"--openshift-service-account={oauth_sa_name}",
        "--upstream=http://localhost:8265",
        f"--tls-cert={tls_mount_location}/tls.crt",
        f"--tls-key={tls_mount_location}/tls.key",
        f"--cookie-secret={cookie_secret}",
        f'--openshift-delegate-urls={{"/":{{"resource":"pods","namespace":"{namespace}","verb":"get"}}}}',
    ]
    return client.V1Container(
        args=proxy_args,
        image="registry.redhat.io/openshift4/ose-oauth-proxy@sha256:1ea6a01bf3e63cdcf125c6064cbd4a4a270deaf0f157b3eabb78f60556840366",
        name="oauth-proxy",
        ports=[client.V1ContainerPort(container_port=oauth_port, name=port_name)],
        resources=client.V1ResourceRequirements(limits=None, requests=None),
        volume_mounts=[
            client.V1VolumeMount(
                mount_path=tls_mount_location, name=tls_volume_name, read_only=True
            )
        ],
    )
419+
350420
def generate_appwrapper(
351421
name: str,
352422
namespace: str,
@@ -365,6 +435,7 @@ def generate_appwrapper(
365435
image_pull_secrets: list,
366436
dispatch_priority: str,
367437
priority_val: int,
438+
openshift_oauth: bool,
368439
):
369440
user_yaml = read_template(template)
370441
appwrapper_name, cluster_name = gen_names(name)
@@ -396,6 +467,10 @@ def generate_appwrapper(
396467
enable_local_interactive(resources, cluster_name, namespace)
397468
else:
398469
disable_raycluster_tls(resources["resources"])
470+
471+
if openshift_oauth:
472+
enable_openshift_oauth(user_yaml, cluster_name, namespace)
473+
399474
outfile = appwrapper_name + ".yaml"
400475
write_user_appwrapper(user_yaml, outfile)
401476
return outfile
+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from urllib3.util import parse_url
2+
from .generate_yaml import gen_dashboard_route_name
3+
from base64 import b64decode
4+
5+
from kubernetes import config, client
6+
7+
k8_client = config.new_client_from_config()
8+
core_api = client.CoreV1Api(k8_client)
9+
rbac_auth_api = client.RbacAuthorizationV1Api(k8_client)
10+
networking_api = client.NetworkingV1Api(k8_client)
11+
12+
def create_openshift_oauth_objects(cluster_name, namespace):
    """
    Create the cluster-side objects the OAuth proxy needs for one Ray cluster.

    Args:
        cluster_name: Cluster name used to derive every object name.
        namespace: Namespace the namespaced objects are created in.

    Creates, in this order: the proxy service account (annotated with the
    OAuth redirect URI), the proxy service, the passthrough ingress, and a
    cluster role binding of the service account to ``system:auth-delegator``.
    """
    oauth_port = 443
    oauth_sa_name = f"{cluster_name}-oauth-proxy"
    tls_secret_name = _gen_tls_secret_name(cluster_name)
    service_name = f"{cluster_name}-oauth"
    port_name = "oauth-proxy"
    api_host = parse_url(k8_client.configuration.host).host

    # replace "^api" with the expected host
    # str.lstrip("api") strips any run of the characters a/p/i rather than the
    # literal prefix, so the prefix is removed explicitly.
    if api_host.startswith("api"):
        api_host = api_host[len("api"):]
    host = f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps" + api_host

    oauth_crb = client.V1ClusterRoleBinding(
        api_version="rbac.authorization.k8s.io/v1",
        kind="ClusterRoleBinding",
        metadata=client.V1ObjectMeta(name=f"{cluster_name}-rb"),
        role_ref=client.V1RoleRef(
            api_group="rbac.authorization.k8s.io",
            kind="ClusterRole",
            name="system:auth-delegator",
        ),
        subjects=[
            client.V1Subject(
                kind="ServiceAccount", name=oauth_sa_name, namespace=namespace
            )
        ],
    )
    oauth_sa = client.V1ServiceAccount(
        api_version="v1",
        kind="ServiceAccount",
        metadata=client.V1ObjectMeta(
            name=oauth_sa_name,
            namespace=namespace,
            annotations={
                "serviceaccounts.openshift.io/oauth-redirecturi.first": f"https://{host}"
            },
        ),
    )
    oauth_service = _create_oauth_service_obj(
        cluster_name, namespace, oauth_port, tls_secret_name, service_name, port_name
    )
    ingress = _create_oauth_ingress_object(
        cluster_name, namespace, service_name, port_name, host
    )

    core_api.create_namespaced_service_account(namespace=namespace, body=oauth_sa)
    core_api.create_namespaced_service(namespace=namespace, body=oauth_service)
    networking_api.create_namespaced_ingress(namespace=namespace, body=ingress)
    rbac_auth_api.create_cluster_role_binding(body=oauth_crb)
50+
51+
def _gen_tls_secret_name(cluster_name):
52+
return f"{cluster_name}-proxy-tls-secret"
53+
54+
def delete_openshift_oauth_objects(cluster_name, namespace):
    """
    Delete the OAuth proxy objects belonging to one cluster.

    Removes, in this order: the proxy service account, the proxy service, the
    ingress, and the cluster role binding, all named after ``cluster_name``.
    """
    core_api.delete_namespaced_service_account(
        name=f"{cluster_name}-oauth-proxy", namespace=namespace
    )
    core_api.delete_namespaced_service(
        name=f"{cluster_name}-oauth", namespace=namespace
    )
    networking_api.delete_namespaced_ingress(
        name=f"{cluster_name}-ingress", namespace=namespace
    )
    rbac_auth_api.delete_cluster_role_binding(name=f"{cluster_name}-rb")
61+
62+
def download_tls_cert(cluster_name, namespace, output_file):
    """
    Write the cluster's proxy TLS certificate to ``output_file``.

    Reads the ``tls.crt`` entry of the cluster's TLS secret in ``namespace``,
    base64-decodes it and writes the ASCII text to the given path.
    """
    tls_secret = core_api.read_namespaced_secret(
        name=_gen_tls_secret_name(cluster_name=cluster_name), namespace=namespace
    )
    encoded_cert = tls_secret.data["tls.crt"]
    with open(output_file, "w+") as cert_file:
        decoded_cert = b64decode(encoded_cert)
        cert_file.write(decoded_cert.decode("ascii"))
68+
69+
def _create_oauth_service_obj(
    cluster_name: str,
    namespace: str,
    oauth_port: int,
    tls_secret_name: str,
    service_name: str,
    port_name: str,
) -> client.V1Service:
    """
    Build the Service object for the OAuth proxy port.

    The service carries the ``serving-cert-secret-name`` annotation with
    ``tls_secret_name``, exposes ``oauth_port`` over TCP under ``port_name``,
    and selects pods by the kuberay head-node labels of ``cluster_name``.
    """
    service_metadata = client.V1ObjectMeta(
        annotations={
            "service.beta.openshift.io/serving-cert-secret-name": tls_secret_name
        },
        name=service_name,
        namespace=namespace,
    )
    proxy_port = client.V1ServicePort(
        name=port_name, protocol="TCP", port=oauth_port, target_port=oauth_port
    )
    head_pod_selector = {
        "app.kubernetes.io/created-by": "kuberay-operator",
        "app.kubernetes.io/name": "kuberay",
        "ray.io/cluster": cluster_name,
        "ray.io/identifier": f"{cluster_name}-head",
        "ray.io/node-type": "head",
    }
    return client.V1Service(
        api_version="v1",
        kind="Service",
        metadata=service_metadata,
        spec=client.V1ServiceSpec(ports=[proxy_port], selector=head_pod_selector),
    )
96+
97+
def _create_oauth_ingress_object(
    cluster_name: str,
    namespace: str,
    service_name: str,
    port_name: str,
    host: str,
) -> client.V1Ingress:
    """
    Build the Ingress object that routes ``host`` to the OAuth proxy service.

    The ingress is named ``<cluster_name>-ingress``, is annotated for
    passthrough termination, and has a single rule with one
    ``ImplementationSpecific`` path backed by ``service_name`` on the port
    named ``port_name``.
    """
    service_backend = client.V1IngressServiceBackend(
        name=service_name, port=client.V1ServiceBackendPort(name=port_name)
    )
    proxy_path = client.V1HTTPIngressPath(
        backend=client.V1IngressBackend(service=service_backend),
        path_type="ImplementationSpecific",
    )
    host_rule = client.V1IngressRule(
        host=host,
        http=client.V1HTTPIngressRuleValue(paths=[proxy_path]),
    )
    ingress_metadata = client.V1ObjectMeta(
        annotations={"route.openshift.io/termination": "passthrough"},
        name=f"{cluster_name}-ingress",
        namespace=namespace,
    )
    return client.V1Ingress(
        api_version="networking.k8s.io/v1",
        kind="Ingress",
        metadata=ingress_metadata,
        spec=client.V1IngressSpec(rules=[host_rule]),
    )

0 commit comments

Comments
 (0)