diff --git a/.dockerignore b/.dockerignore
index 97f42add..79d8f756 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -19,3 +19,5 @@ coverage.xml
 .git
 redis-data
 db-data
+cluster-db-data
+./workers/cluster-db-data
\ No newline at end of file
diff --git a/.github/workflows/dev-workers-deploy.yml b/.github/workflows/dev-workers-deploy.yml
index 2923b154..5d5c3a3c 100644
--- a/.github/workflows/dev-workers-deploy.yml
+++ b/.github/workflows/dev-workers-deploy.yml
@@ -20,10 +20,10 @@ env:
   REGISTRY_HOSTNAME: gcr.io
   PROJECT: ${{ secrets.DEV_GKE_PROJECT }}
-  HOST: dev.compute.studio
   TAG: ${{ github.sha }}
   CS_CONFIG: ${{ secrets.DEV_CS_CONFIG }}
+  DEV_WORKERS_VALUES: ${{ secrets.DEV_WORKERS_VALUES }}

 jobs:
   setup-build-publish-deploy:
@@ -58,9 +58,10 @@ jobs:
         # Set up docker to authenticate
         gcloud auth configure-docker

-    - name: Set cs-config.yaml file.
+    - name: Set cs-config.yaml and values.yaml files.
       run: |
         echo $CS_CONFIG | base64 --decode > cs-config.yaml
+        echo $DEV_WORKERS_VALUES | base64 --decode > ./workers/values.deploy.yaml

     - name: Build Docker Images
       run: |
@@ -74,5 +75,12 @@ jobs:
     - name: Deploy
       run: |
         gcloud container clusters get-credentials $GKE_CLUSTER --zone $GKE_ZONE --project $GKE_PROJECT
-        cs workers svc config -o - --update-dns | kubectl apply -f -
-        kubectl get pods -o wide
+        cd workers
+        helm template cs-workers \
+          --set project=$PROJECT \
+          --set tag=$TAG \
+          --set api.secret_key=$(cs secrets get WORKERS_API_SECRET_KEY) \
+          --set db.password=$(cs secrets get WORKERS_DB_PASSWORD) \
+          --set redis.password=$(cs secrets get WORKERS_REDIS_PASSWORD) \
+          --namespace workers \
+          -f values.deploy.yaml | kubectl apply -f -
diff --git a/.github/workflows/workers-deploy.yml b/.github/workflows/workers-deploy.yml
index 8042c737..4f33e6f3 100644
--- a/.github/workflows/workers-deploy.yml
+++ b/.github/workflows/workers-deploy.yml
@@ -20,10 +20,10 @@ env:
   REGISTRY_HOSTNAME: gcr.io
   PROJECT: ${{ secrets.GKE_PROJECT }}
-  HOST: compute.studio
   TAG: ${{ github.sha }}
   CS_CONFIG: ${{ secrets.CS_CONFIG }}
+  WORKERS_VALUES: ${{ secrets.WORKERS_VALUES }}

 jobs:
   setup-build-publish-deploy:
@@ -58,9 +58,10 @@ jobs:
         # Set up docker to authenticate
         gcloud auth configure-docker

-    - name: Set cs-config.yaml file.
+    - name: Set cs-config.yaml and values.yaml files.
       run: |
         echo $CS_CONFIG | base64 --decode > cs-config.yaml
+        echo $WORKERS_VALUES | base64 --decode > ./workers/values.deploy.yaml

     - name: Build Docker Images
       run: |
@@ -74,5 +75,12 @@ jobs:
     - name: Deploy
       run: |
         gcloud container clusters get-credentials $GKE_CLUSTER --zone $GKE_ZONE --project $GKE_PROJECT
-        cs workers svc config -o - --update-dns | kubectl apply -f -
-        kubectl get pods -o wide
+        cd workers
+        helm template cs-workers \
+          --set project=$PROJECT \
+          --set tag=$TAG \
+          --set api.secret_key=$(cs secrets get WORKERS_API_SECRET_KEY) \
+          --set db.password=$(cs secrets get WORKERS_DB_PASSWORD) \
+          --set redis.password=$(cs secrets get WORKERS_REDIS_PASSWORD) \
+          --namespace workers \
+          -f values.deploy.yaml | kubectl apply -f -
diff --git a/kind-config.yaml b/kind-config.yaml
index fc35654b..f5557d68 100644
--- a/kind-config.yaml
+++ b/kind-config.yaml
@@ -4,8 +4,12 @@ nodes:
   - role: control-plane
   - role: worker
     extraMounts:
-      - hostPath: /home/hankdoupe/compute-studio/redis-data
-        containerPath: /redis-data
+      - hostPath: /home/hankdoupe/compute-studio/redis-queue-data
+        containerPath: /redis-queue-data
+      - hostPath: /home/hankdoupe/compute-studio/workers-db-data
+        containerPath: /workers-db-data
+      - hostPath: /home/hankdoupe/compute-studio
+        containerPath: /code
  - role: worker
  - role: worker
    extraMounts:
diff --git a/pytest.ini b/pytest.ini
index ca5c3a32..1aba610a 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -2,3 +2,4 @@
 DJANGO_SETTINGS_MODULE = webapp.settings
 markers =
     register: for the register module.
+    requires_stripe: marks tests that require Stripe.
\ No newline at end of file
diff --git a/src/Simulation/API.ts b/src/Simulation/API.ts
index 676183c0..0cd68377 100644
--- a/src/Simulation/API.ts
+++ b/src/Simulation/API.ts
@@ -43,43 +43,42 @@ export default class API {
     });
   }

-  getInputsDetail(): Promise<InputsDetail> {
+  async getInputsDetail(): Promise<InputsDetail> {
     if (!this.modelpk) return;
-    return axios
-      .get(`/${this.owner}/${this.title}/api/v1/${this.modelpk}/edit/`)
-      .then(resp => resp.data);
+    const resp = await axios.get(`/${this.owner}/${this.title}/api/v1/${this.modelpk}/edit/`);
+    return resp.data;
   }

-  getInitialValues(): Promise<Inputs> {
-    let data: Inputs;
-    if (!this.modelpk) {
-      return axios.get(`/${this.owner}/${this.title}/api/v1/inputs/`).then(inputsResp => {
-        data = inputsResp.data;
-        return data;
+  async getInputs(meta_parameters?: InputsDetail["meta_parameters"]): Promise<Inputs> {
+    let resp;
+    if (!!meta_parameters) {
+      resp = await axios.post(`/${this.owner}/${this.title}/api/v1/inputs/`, meta_parameters);
+    } else {
+      resp = await axios.get(`/${this.owner}/${this.title}/api/v1/inputs/`);
+    }
+    if (resp.status === 202) {
+      return new Promise(resolve => {
+        setTimeout(async () => resolve(await this.getInputs(meta_parameters)), 2000);
       });
     } else {
-      return axios
-        .get(`/${this.owner}/${this.title}/api/v1/${this.modelpk}/edit/`)
-        .then(detailResp => {
-          return axios
-            .post(`/${this.owner}/${this.title}/api/v1/inputs/`, {
-              meta_parameters: detailResp.data.meta_parameters,
-            })
-            .then(inputsResp => {
-              data = inputsResp.data;
-              data["detail"] = detailResp.data;
-              return data;
-            });
-        });
+      return resp.data;
     }
   }

-  resetInitialValues(metaParameters: { [metaParam: string]: any }): Promise<Inputs> {
-    return axios
-      .post(`/${this.owner}/${this.title}/api/v1/inputs/`, metaParameters)
-      .then(response => {
-        return response.data;
+  async resetInitialValues(metaParameters: { [metaParam: string]: any }): Promise<Inputs> {
+    let resp;
+    if (!!metaParameters) {
+      resp = await axios.post(`/${this.owner}/${this.title}/api/v1/inputs/`, metaParameters);
+    } else {
+      resp = await axios.get(`/${this.owner}/${this.title}/api/v1/inputs/`);
+    }
+    if (resp.status === 202) {
+      return new Promise(resolve => {
+        setTimeout(async () => resolve(await this.getInputs(metaParameters)), 2000);
       });
+    } else {
+      return resp.data;
+    }
   }

   getAccessStatus(): Promise<AccessStatus> {
diff --git a/src/Simulation/index.tsx b/src/Simulation/index.tsx
index ca0dbca5..ba41fd20 100755
--- a/src/Simulation/index.tsx
+++ b/src/Simulation/index.tsx
@@ -128,79 +128,78 @@ class SimTabs extends React.Component<
     this.handleSubmit = this.handleSubmit.bind(this);
   }

-  componentDidMount() {
+  async componentDidMount() {
     this.api.getAccessStatus().then(data => {
       this.setState({
         accessStatus: data,
       });
     });
-    this.api
-      .getInitialValues()
-      .then(data => {
-        const [serverValues, sects, inputs, schema, unknownParams] = convertToFormik(data);
-        let isEmpty = true;
-        for (const msectvals of Object.values(data.detail?.adjustment || {})) {
-          if (Object.keys(msectvals).length > 0) {
-            isEmpty = false;
-          }
-        }
-        let initialValues;
-        if (isEmpty) {
-          const storage = Persist.pop(
-            `${this.props.match.params.owner}/${this.props.match.params.title}/inputs`
-          );
-          // Use values from local storage if available. Default to empty dict from server.
-          initialValues = storage || serverValues;
-        } else {
-          initialValues = serverValues;
-        }
-
-        this.setState({
-          inputs: inputs,
-          initialValues: initialValues,
-          sects: sects,
-          schema: schema,
-          unknownParams: unknownParams,
-          extend: "extend" in data ? data.extend : false,
-        });
-      })
-      .catch(error => {
-        this.setState({ error });
-      });
     if (this.api.modelpk) {
       this.setOutputs();
     }
+    let data: Inputs;
+    if (this.api.modelpk) {
+      const detail = await this.api.getInputsDetail();
+      data = await this.api.getInputs(detail.meta_parameters);
+      data.detail = detail;
+    } else {
+      data = await this.api.getInputs();
+    }
+
+    const [serverValues, sects, inputs, schema, unknownParams] = convertToFormik(data);
+    let isEmpty = true;
+    for (const msectvals of Object.values(data.detail?.adjustment || {})) {
+      if (Object.keys(msectvals).length > 0) {
+        isEmpty = false;
+      }
+    }
+    let initialValues;
+    if (isEmpty) {
+      const storage = Persist.pop(
+        `${this.props.match.params.owner}/${this.props.match.params.title}/inputs`
+      );
+      // Use values from local storage if available. Default to empty dict from server.
+      initialValues = storage || serverValues;
+    } else {
+      initialValues = serverValues;
+    }
+
+    this.setState({
+      inputs: inputs,
+      initialValues: initialValues,
+      sects: sects,
+      schema: schema,
+      unknownParams: unknownParams,
+      extend: "extend" in inputs ? inputs.extend : false,
+    });
   }

-  resetInitialValues(metaParameters: InputsDetail["meta_parameters"]) {
+  async resetInitialValues(metaParameters: InputsDetail["meta_parameters"]) {
     this.setState({ resetting: true });
-    this.api
-      .resetInitialValues({
-        meta_parameters: tbLabelSchema.cast(metaParameters),
-      })
-      .then(data => {
-        const [
-          initialValues,
-          sects,
-          { meta_parameters, model_parameters },
-          schema,
-          unknownParams,
-        ] = convertToFormik(data);
-        this.setState(prevState => ({
-          inputs: {
-            ...prevState.inputs,
-            ...{
-              meta_parameters: meta_parameters,
-              model_parameters: model_parameters,
-            },
-          },
-          initialValues: initialValues,
-          sects: sects,
-          schema: schema,
-          unknownParams: unknownParams,
-          resetting: false,
-        }));
-      });
+    const data = await this.api.resetInitialValues({
+      meta_parameters: tbLabelSchema.cast(metaParameters),
+    });
+    const [
+      initialValues,
+      sects,
+      { meta_parameters, model_parameters },
+      schema,
+      unknownParams,
+    ] = convertToFormik(data);
+    this.setState(prevState => ({
+      inputs: {
+        ...prevState.inputs,
+        ...{
+          meta_parameters: meta_parameters,
+          model_parameters: model_parameters,
+        },
+      },
+      initialValues: initialValues,
+      sects: sects,
+      schema: schema,
+      unknownParams: unknownParams,
+      resetting: false,
+    }));
   }

   resetAccessStatus() {
diff --git a/webapp/apps/comp/asyncsubmit.py b/webapp/apps/comp/asyncsubmit.py
index 3a87f896..46cb8edc 100755
--- a/webapp/apps/comp/asyncsubmit.py
+++ b/webapp/apps/comp/asyncsubmit.py
@@ -128,7 +128,11 @@ def submit(self):
         project = self.sim.project
         tag = str(project.latest_tag)
         self.submitted_id = self.compute.submit_job(
-            project=inputs.project, task_name=actions.SIM, task_kwargs=data, tag=tag,
+            project=inputs.project,
+            task_name=actions.SIM,
+            task_kwargs=data,
+            tag=tag,
+            path_prefix="/api/v1/jobs" if project.cluster.version == "v1" else "",
         )
         print(f"job id: {self.submitted_id}")
diff --git a/webapp/apps/comp/compute.py b/webapp/apps/comp/compute.py
index 316673c3..bfd3e8da 100755
--- a/webapp/apps/comp/compute.py
+++ b/webapp/apps/comp/compute.py
@@ -28,10 +28,14 @@ def remote_submit_job(
         response = requests.post(url, json=data, timeout=timeout, headers=headers)
         return response

-    def submit_job(self, project, task_name, task_kwargs, tag=None):
-        print("submitting", task_name)
+    def submit_job(self, project, task_name, task_kwargs, path_prefix="", tag=None):
+        print(
+            "submitting", task_name,
+        )
         cluster = project.cluster
-        url = f"{cluster.url}/{project.owner}/{project.title}/"
+        tag = tag or str(project.latest_tag)
+        url = f"{cluster.url}{path_prefix}/{project.owner}/{project.title}/"
+        print(url)
         return self.submit(
             tasks=dict(task_name=task_name, tag=tag, task_kwargs=task_kwargs),
             url=url,
@@ -43,16 +47,17 @@ def submit(self, tasks, url, headers):
         attempts = 0
         while not submitted:
             try:
+                print(tasks)
                 response = self.remote_submit_job(
                     url, data=tasks, timeout=TIMEOUT_IN_SECONDS, headers=headers
                 )
-                if response.status_code == 200:
+                if response.status_code in (200, 201):
                     print("submitted: ", url)
                     submitted = True
                     data = response.json()
-                    job_id = data["task_id"]
+                    job_id = data.get("task_id") or data.get("id")
                 else:
-                    print("FAILED: ", url, response.status_code)
+                    print("FAILED: ", url, response.status_code, response.json())
                     attempts += 1
             except Timeout:
                 print("Couldn't submit to: ", url)
@@ -83,7 +88,7 @@ def submit(self, tasks, url, headers):
                     return
                 data = response.json()
             else:
-                print("FAILED: ", url, response.status_code)
+                print("FAILED: ", url, response.status_code, response.text)
                 attempts += 1
             except Timeout:
                 print("Couldn't submit to: ", url)
@@ -95,15 +100,19 @@ def submit(self, tasks, url, headers):
             print("Exceeded max attempts. Bailing out.")
             raise WorkersUnreachableError()

-        success = data["status"] == "SUCCESS"
-        if success:
-            return success, data
+        if isinstance(data, list):
+            success = True
         else:
-            return success, data
+            success = data["status"] == "SUCCESS"
+
+        return success, data


 class SyncProjects(SyncCompute):
     def submit_job(self, project, cluster):
-        url = f"{cluster.url}/sync/"
+        if cluster.version == "v0":
+            url = f"{cluster.url}/sync/"
+        else:
+            url = f"{cluster.url}/api/v1/projects/sync/"
         headers = cluster.headers()
         return self.submit(tasks=[project], url=url, headers=headers)
diff --git a/webapp/apps/comp/exceptions.py b/webapp/apps/comp/exceptions.py
index 6b457e8a..86e6baf9 100755
--- a/webapp/apps/comp/exceptions.py
+++ b/webapp/apps/comp/exceptions.py
@@ -72,3 +72,15 @@ def todict(self):
             collaborator=getattr(self.collaborator, "username", str(self.collaborator)),
             msg=str(self),
         )
+
+
+class NotReady(CSException):
+    def __init__(self, instance, *args, **kwargs):
+        self.instance = instance
+        super().__init__(*args, **kwargs)
+
+
+class Stale(CSException):
+    def __init__(self, instance, *args, **kwargs):
+        self.instance = instance
+        super().__init__(*args, **kwargs)
diff --git a/webapp/apps/comp/ioutils.py b/webapp/apps/comp/ioutils.py
index ac8f9621..13a9e096 100755
--- a/webapp/apps/comp/ioutils.py
+++ b/webapp/apps/comp/ioutils.py
@@ -8,8 +8,10 @@ class IOClasses(NamedTuple):
     Parser: Type[Parser]


-def get_ioutils(project, **kwargs):
+def get_ioutils(project, compute=None, **kwargs):
     return IOClasses(
-        model_parameters=kwargs.get("ModelParameters", ModelParameters)(project),
+        model_parameters=kwargs.get("ModelParameters", ModelParameters)(
+            project, compute=compute
+        ),
         Parser=kwargs.get("Parser", Parser),
     )
diff --git a/webapp/apps/comp/migrations/0029_auto_20210321_2247.py b/webapp/apps/comp/migrations/0029_auto_20210321_2247.py
new file mode 100755
index 00000000..bfbaf68f
--- /dev/null
+++ b/webapp/apps/comp/migrations/0029_auto_20210321_2247.py
@@ -0,0 +1,35 @@
+# Generated by Django 3.0.13 on 2021-03-21 22:47
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("comp", "0028_simulation_tag"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="modelconfig",
+            name="job_id",
+            field=models.UUIDField(blank=True, default=None, null=True),
+        ),
+        migrations.AddField(
+            model_name="modelconfig",
+            name="status",
+            field=models.CharField(
+                choices=[
+                    ("STARTED", "Started"),
+                    ("PENDING", "Pending"),
+                    ("SUCCESS", "Success"),
+                    ("INVALID", "Invalid"),
+                    ("FAIL", "Fail"),
+                    ("WORKER_FAILURE", "Worker Failure"),
+                ],
+                default="SUCCESS",
+                max_length=20,
+            ),
+            preserve_default=False,
+        ),
+    ]
diff --git a/webapp/apps/comp/model_parameters.py b/webapp/apps/comp/model_parameters.py
index 0ee4f8cb..7902c5e5 100755
--- a/webapp/apps/comp/model_parameters.py
+++ b/webapp/apps/comp/model_parameters.py
@@ -1,9 +1,13 @@
+from typing import Union
+from django.db.models.base import Model
+
+
 import paramtools as pt

 from webapp.apps.comp.models import ModelConfig
-from webapp.apps.comp.compute import SyncCompute, JobFailError
+from webapp.apps.comp.compute import Compute, SyncCompute, JobFailError
 from webapp.apps.comp import actions
-from webapp.apps.comp.exceptions import AppError
+from webapp.apps.comp.exceptions import AppError, NotReady, Stale

 import os

@@ -21,9 +25,16 @@ class ModelParameters:
     Handles logic for getting cached model parameters and updating the cache.
     """

-    def __init__(self, project: "Project", compute: SyncCompute = None):
+    def __init__(self, project: "Project", compute: Union[SyncCompute, Compute] = None):
         self.project = project
-        self.compute = compute or SyncCompute()
+        print(self.project)
+        if compute is not None:
+            self.compute = compute
+        elif self.project.cluster.version == "v0":
+            self.compute = SyncCompute()
+        else:
+            self.compute = Compute()
+
         self.config = None

     def defaults(self, init_meta_parameters=None):
@@ -38,9 +49,11 @@ def defaults(self, init_meta_parameters=None):
             "meta_parameters": meta_parameters,
         }

-    def meta_parameters_parser(self):
+    def meta_parameters_parser(self) -> pt.Parameters:
         res = self.get_inputs()
-        return pt_factory("MetaParametersParser", res["meta_parameters"])()
+        params = pt_factory("MetaParametersParser", res["meta_parameters"])()
+        # params._defer_validation = True
+        return params

     def model_parameters_parser(self, meta_parameters_values=None):
         res = self.get_inputs(meta_parameters_values)
@@ -53,46 +66,80 @@ def model_parameters_parser(self, meta_parameters_values=None):
         # return model_parameters_parser
         return res["model_parameters"]

+    def cleanup_meta_parameters(self, meta_parameters_values, meta_parameters):
+        # clean up meta parameters before saving them.
+        if not meta_parameters_values:
+            return {}
+
+        mp = pt_factory("MP", meta_parameters)()
+        mp.adjust(meta_parameters_values)
+        return mp.specification(meta_data=False, serializable=True)
+
     def get_inputs(self, meta_parameters_values=None):
         """
         Get cached version of inputs or retrieve new version.
         """
         meta_parameters_values = meta_parameters_values or {}
-
+        self.config = None
         try:
-            config = ModelConfig.objects.get(
+            self.config = ModelConfig.objects.get(
                 project=self.project,
                 model_version=str(self.project.latest_tag),
                 meta_parameters_values=meta_parameters_values,
             )
+            print("STATUS", self.config.status)
+            if self.config.status != "SUCCESS" and not self.config.is_stale():
+                raise NotReady(self.config)
+            elif self.config.status != "SUCCESS" and self.config.is_stale():
+                raise Stale(self.config)
-        except ModelConfig.DoesNotExist:
-            success, result = self.compute.submit_job(
+        except (ModelConfig.DoesNotExist, Stale) as e:
+            response = self.compute.submit_job(
                 project=self.project,
                 task_name=actions.INPUTS,
                 task_kwargs={"meta_param_dict": meta_parameters_values or {}},
+                path_prefix="/api/v1/jobs"
+                if self.project.cluster.version == "v1"
+                else "",
             )
+            if self.project.cluster.version == "v1" and isinstance(
+                e, ModelConfig.DoesNotExist
+            ):
+                self.config = ModelConfig.objects.create(
+                    project=self.project,
+                    model_version=str(self.project.latest_tag),
+                    meta_parameters_values=meta_parameters_values,
+                    inputs_version="v1",
+                    job_id=response,
+                    status="PENDING",
+                )
+                raise NotReady(self.config)
+            elif self.project.cluster.version == "v1" and isinstance(e, Stale):
+                self.config.model_version = str(self.project.latest_tag)
+                self.config.job_id = response
+                self.config.status = "PENDING"
+                self.config.save()
+                raise NotReady(self.config)
+
+            success, result = response
             if not success:
                 raise AppError(meta_parameters_values, result["traceback"])
-            # clean up meta parameters before saving them.
-            if meta_parameters_values:
-                mp = pt_factory("MP", result["meta_parameters"])()
-                mp.adjust(meta_parameters_values)
-                save_vals = mp.specification(meta_data=False, serializable=True)
-            else:
-                save_vals = {}
+            save_vals = self.cleanup_meta_parameters(
+                meta_parameters_values, result["meta_parameters"]
+            )

-            config = ModelConfig.objects.create(
+            self.config = ModelConfig.objects.create(
                 project=self.project,
                 model_version=str(self.project.latest_tag),
                 meta_parameters_values=save_vals,
                 meta_parameters=result["meta_parameters"],
                 model_parameters=result["model_parameters"],
                 inputs_version="v1",
+                status="SUCCESS",
             )
-        self.config = config

         return {
-            "meta_parameters": config.meta_parameters,
-            "model_parameters": config.model_parameters,
+            "meta_parameters": self.config.meta_parameters,
+            "model_parameters": self.config.model_parameters,
         }
diff --git a/webapp/apps/comp/models.py b/webapp/apps/comp/models.py
index 36ec81a8..68c7cd91 100755
--- a/webapp/apps/comp/models.py
+++ b/webapp/apps/comp/models.py
@@ -77,6 +77,19 @@ class ModelConfig(models.Model):
     meta_parameters = JSONField(default=dict)
     model_parameters = JSONField(default=dict)

+    job_id = models.UUIDField(blank=True, default=None, null=True)
+    status = models.CharField(
+        choices=(
+            ("STARTED", "Started"),
+            ("PENDING", "Pending"),
+            ("SUCCESS", "Success"),
+            ("INVALID", "Invalid"),
+            ("FAIL", "Fail"),
+            ("WORKER_FAILURE", "Worker Failure"),
+        ),
+        max_length=20,
+    )
+
     objects = ModelConfigManager()

     class Meta:
@@ -87,6 +100,12 @@ class Meta:
             )
         ]

+    def is_stale(self, timeout=10):
+        return (
+            self.status != "SUCCESS"
+            and (timezone.now() - self.creation_date).total_seconds() > timeout
+        )
+

 class Inputs(models.Model):
     objects: models.Manager
diff --git a/webapp/apps/comp/parser.py b/webapp/apps/comp/parser.py
index 0f3d76f9..e4584ccf 100755
--- a/webapp/apps/comp/parser.py
+++ b/webapp/apps/comp/parser.py
@@ -3,7 +3,7 @@
 from webapp.apps.comp import actions
 from webapp.apps.comp.compute import Compute
-from webapp.apps.comp.exceptions import AppError
+from webapp.apps.comp.exceptions import AppError, NotReady
 from webapp.apps.comp.models import Inputs

 ParamData = namedtuple("ParamData", ["name", "data"])
@@ -23,11 +23,18 @@ def __init__(
         self.valid_meta_params = valid_meta_params
         for param, value in valid_meta_params.items():
             setattr(self, param, value)
-        defaults = model_parameters.defaults(self.valid_meta_params)
-        self.grouped_defaults = defaults["model_parameters"]
-        self.flat_defaults = {
-            k: v for _, sect in self.grouped_defaults.items() for k, v in sect.items()
-        }
+        try:
+            defaults = model_parameters.defaults(self.valid_meta_params)
+        except NotReady:
+            self.grouped_defaults = {}
+            self.flat_defaults = {}
+        else:
+            self.grouped_defaults = defaults["model_parameters"]
+            self.flat_defaults = {
+                k: v
+                for _, sect in self.grouped_defaults.items()
+                for k, v in sect.items()
+            }

     @staticmethod
     def append_errors_warnings(errors_warnings, append_func, defaults=None):
@@ -41,11 +48,11 @@ def append_errors_warnings(errors_warnings, append_func, defaults=None):
             append_func(param, msg, defaults)

     def parse_parameters(self):
+        sects = set(self.grouped_defaults.keys()) | set(self.clean_inputs.keys())
         errors_warnings = {
-            sect: {"errors": {}, "warnings": {}}
-            for sect in list(self.grouped_defaults) + ["GUI", "API"]
+            sect: {"errors": {}, "warnings": {}} for sect in sects | {"GUI", "API"}
         }
-        adjustment = {sect: {} for sect in self.grouped_defaults}
+        adjustment = defaultdict(dict)

         return errors_warnings, adjustment

     def post(self, errors_warnings, params):
@@ -55,7 +62,10 @@ def post(self, errors_warnings, params):
             "errors_warnings": errors_warnings,
         }
         job_id = self.compute.submit_job(
-            project=self.project, task_name=actions.PARSE, task_kwargs=data
+            project=self.project,
+            task_name=actions.PARSE,
+            task_kwargs=data,
+            path_prefix="/api/v1/jobs" if self.project.cluster.version == "v1" else "",
         )
         return job_id

@@ -67,13 +77,8 @@ class Parser:

 class APIParser(BaseParser):
     def parse_parameters(self):
         errors_warnings, adjustment = super().parse_parameters()
-        extra_keys = set(self.clean_inputs.keys() - self.grouped_defaults.keys())
-        if extra_keys:
-            errors_warnings["API"]["errors"] = {
-                "extra_keys": [f"Has extra sections: {' ,'.join(extra_keys)}"]
-            }
-
-        for sect in adjustment:
+        sects = set(self.grouped_defaults.keys()) | set(self.clean_inputs.keys())
+        for sect in sects:
             adjustment[sect].update(self.clean_inputs.get(sect, {}))

         # kick off async parsing
diff --git a/webapp/apps/comp/serializers.py b/webapp/apps/comp/serializers.py
index e1652abe..1597a8bd 100755
--- a/webapp/apps/comp/serializers.py
+++ b/webapp/apps/comp/serializers.py
@@ -13,7 +13,7 @@ class OutputsSerializer(serializers.Serializer):
     job_id = serializers.UUIDField()
     status = serializers.ChoiceField(choices=(("SUCCESS", "Success"), ("FAIL", "Fail")))
-    traceback = serializers.CharField(required=False)
+    traceback = serializers.CharField(required=False, allow_null=True)
     model_version = serializers.CharField(required=False)
     meta = serializers.JSONField()
     outputs = serializers.JSONField(required=False)
@@ -107,12 +107,30 @@ class Meta:
         )

+class ModelConfigAsyncSerializer(serializers.Serializer):
+    job_id = serializers.UUIDField(required=False)
+    status = serializers.ChoiceField(
+        choices=(("SUCCESS", "Success"), ("FAIL", "Fail")), required=False
+    )
+    outputs = serializers.JSONField(required=False)
+
+    def to_internal_value(self, data):
+        if "outputs" in data:
+            data.update(**data.pop("outputs"))
+        if "task_id" in data:
+            data["job_id"] = data.pop("task_id")
+        print(data.keys())
+        return super().to_internal_value(data)
+
+
 class ModelConfigSerializer(serializers.ModelSerializer):
     project = serializers.StringRelatedField()

     class Meta:
         model = ModelConfig
         fields = (
+            "job_id",
+            "status",
             "project",
             "model_version",
             "meta_parameters_values",
@@ -130,6 +148,14 @@ class Meta:
             "creation_date",
         )

+    def to_internal_value(self, data):
+        if "outputs" in data:
+            data.update(**data.pop("outputs"))
+        if "task_id" in data:
+            data["job_id"] = data.pop("task_id")
+        print(data.keys())
+        return super().to_internal_value(data)
+

 class InputsSerializer(serializers.ModelSerializer):
     """
diff --git a/webapp/apps/comp/tests/test_api_parser.py b/webapp/apps/comp/tests/test_api_parser.py
index f1d1861b..27d3de46 100755
--- a/webapp/apps/comp/tests/test_api_parser.py
+++ b/webapp/apps/comp/tests/test_api_parser.py
@@ -1,3 +1,4 @@
+import pytest
 from webapp.apps.users.models import Project
 from webapp.apps.comp.model_parameters import ModelParameters
 from webapp.apps.comp.ioutils import get_ioutils
@@ -35,6 +36,10 @@ def get_inputs(self, meta_parameters=None):
     assert errors_warnings["GUI"] == exp_errors_warnings

+# Opting out of this validation for now. It may not be good to have in the long
+# run, but in the short term, requiring the model parameters to be loaded from the db
+# before running a sim is a bottleneck.
+@pytest.mark.xfail
 def test_api_parser_extra_section(db, get_inputs, valid_meta_params):
     class MockMp(ModelParameters):
         def get_inputs(self, meta_parameters=None):
diff --git a/webapp/apps/comp/tests/test_asyncviews.py b/webapp/apps/comp/tests/test_asyncviews.py
index d4a8f205..d41fee7c 100755
--- a/webapp/apps/comp/tests/test_asyncviews.py
+++ b/webapp/apps/comp/tests/test_asyncviews.py
@@ -193,14 +193,16 @@ def post_adjustment(
         adj_resp_data: dict,
         adj: dict,
     ) -> Response:
-        mock.register_uri(
-            "POST",
-            f"{self.project.cluster.url}/{self.project}/",
-            json=lambda request, context: {
+        def mock_json(request, context):
+            return {
                 "defaults": defaults_resp_data,
                 "parse": adj_resp_data,
                 "version": {"status": "SUCCESS", "version": "v1"},
-            }[request.json()["task_name"]],
+                "sim": {"task_id": str(uuid.uuid4())},
+            }[request.json()["task_name"]]
+
+        mock.register_uri(
+            "POST", f"{self.project.cluster.url}/{self.project}/", json=mock_json,
         )

         init_resp = self.api_client.post(
             f"/{self.project}/api/v1/", data=adj, format="json"
diff --git a/webapp/apps/comp/views/__init__.py b/webapp/apps/comp/views/__init__.py
index 13496c57..2863c72e 100755
--- a/webapp/apps/comp/views/__init__.py
+++ b/webapp/apps/comp/views/__init__.py
@@ -20,6 +20,7 @@
     OutputsAPIView,
     DetailMyInputsAPIView,
     MyInputsAPIView,
+    ModelConfigAPIView,
     NewSimulationAPIView,
     AuthorsAPIView,
     AuthorsDeleteAPIView,
diff --git a/webapp/apps/comp/views/api.py b/webapp/apps/comp/views/api.py
index 29b2821b..4cd8b921 100755
--- a/webapp/apps/comp/views/api.py
+++ b/webapp/apps/comp/views/api.py
@@ -25,7 +25,7 @@
 import paramtools as pt
 import cs_storage

-from webapp.apps.users.auth import ClusterAuthentication
+from webapp.apps.users.auth import ClusterAuthentication, ClientOAuth2Authentication
 from webapp.apps.users.models import (
     Project,
     Profile,
@@ -43,15 +43,23 @@
     ForkObjectException,
     PrivateAppException,
     PrivateSimException,
+    NotReady,
 )
 from webapp.apps.comp.ioutils import get_ioutils
-from webapp.apps.comp.models import Inputs, Simulation, PendingPermission, ANON_BEFORE
+from webapp.apps.comp.models import (
+    Inputs,
+    Simulation,
+    PendingPermission,
+    ModelConfig,
+    ANON_BEFORE,
+)
 from webapp.apps.comp.parser import APIParser
 from webapp.apps.comp.serializers import (
     SimulationSerializer,
     MiniSimulationSerializer,
     InputsSerializer,
     OutputsSerializer,
+    ModelConfigSerializer,
     AddAuthorsSerializer,
     SimAccessSerializer,
     PendingPermissionSerializer,
@@ -80,10 +88,14 @@ def get_inputs(self, kwargs, meta_parameters=None):
         ioutils = get_ioutils(project)
         try:
             defaults = ioutils.model_parameters.defaults(meta_parameters)
+        except NotReady:
+            print("NOT READY")
+            return Response(status=202)
         except pt.ValidationError as e:
             return Response(str(e), status=status.HTTP_400_BAD_REQUEST)
         if "year" in defaults["meta_parameters"]:
             defaults.update({"extend": True})
+
         return Response(defaults)

     def get(self, request, *args, **kwargs):
@@ -385,12 +397,14 @@ class OutputsAPIView(RecordOutputsMixin, APIView):
     authentication_classes = (
         ClusterAuthentication,
+        ClientOAuth2Authentication,
         # Uncomment to allow token-based authentication for this endpoint.
         # TokenAuthentication,
     )

     def put(self, request, *args, **kwargs):
-        print("myoutputs api method=PUT", kwargs)
+        print("myoutputs api method=PUT", request.user, kwargs)
+        print("authenticator", request.user, request.successful_authenticator)
         ser = OutputsSerializer(data=request.data)
         if ser.is_valid():
             data = ser.validated_data
@@ -436,18 +450,21 @@ def put(self, request, *args, **kwargs):
                 )
             return Response(status=status.HTTP_200_OK)
         else:
+            print(f"Data from compute cluster is invalid: {ser.errors}")
             return Response(ser.errors, status=status.HTTP_400_BAD_REQUEST)


 class MyInputsAPIView(APIView):
     authentication_classes = (
         ClusterAuthentication,
+        ClientOAuth2Authentication,
         # Uncomment to allow token-based authentication for this endpoint.
         # TokenAuthentication,
     )

     def put(self, request, *args, **kwargs):
         print("myinputs api method=PUT", kwargs)
+        print("authenticator", request.user, request.successful_authenticator)
         ser = InputsSerializer(data=request.data)
         if ser.is_valid():
             data = ser.validated_data
@@ -490,6 +507,39 @@ def put(self, request, *args, **kwargs):
         return Response(ser.errors, status=status.HTTP_400_BAD_REQUEST)


+class ModelConfigAPIView(APIView):
+    authentication_classes = (
+        ClusterAuthentication,
+        ClientOAuth2Authentication,
+        # Uncomment to allow token-based authentication for this endpoint.
+        # TokenAuthentication,
+    )
+
+    def put(self, request, *args, **kwargs):
+        print("model config api method=PUT", kwargs)
+        print("authenticator", request.user, request.successful_authenticator)
+
+        ser = ModelConfigSerializer(data=request.data)
+        if ser.is_valid():
+            data = ser.validated_data
+            model_config = get_object_or_404(
+                ModelConfig.objects.prefetch_related("project"), job_id=data["job_id"]
+            )
+            if model_config.status in ("PENDING", "INVALID", "FAIL"):
+                ioutils = get_ioutils(model_config.project)
+                model_config.meta_parameters_values = ioutils.model_parameters.cleanup_meta_parameters(
+                    model_config.meta_parameters_values, data["meta_parameters"]
+                )
+                model_config.meta_parameters = data["meta_parameters"]
+                model_config.model_parameters = data["model_parameters"]
+                model_config.status = data["status"]
+                model_config.save()
+            return Response(status=status.HTTP_200_OK)
+        else:
+            print("model config put error", ser.errors)
+            return Response(ser.errors, status=status.HTTP_400_BAD_REQUEST)
+
+
 class AuthorsAPIView(RequiresLoginPermissions, GetOutputsObjectMixin, APIView):
     permission_classes = (StrictRequiresActive,)
     authentication_classes = (
diff --git a/webapp/apps/comp/views/views.py b/webapp/apps/comp/views/views.py
index 28b63dba..a98a819e 100755
--- a/webapp/apps/comp/views/views.py
+++ b/webapp/apps/comp/views/views.py
@@ -184,7 +184,7 @@ def get(self, request, *args, **kwargs):
         context["tech"] = project.tech
         context["object"] = project
         context["deployment"] = deployment
-        context["viz_host"] = DEFAULT_VIZ_HOST
+        context["viz_host"] = project.cluster.viz_host or DEFAULT_VIZ_HOST
         context["protocol"] = "https"

         return render(request, self.template_name, context)
@@ -212,7 +212,7 @@ def get(self, request, *args, **kwargs):
             "object": project,
             "deployment": deployment,
             "protocol": "https",
-            "viz_host": DEFAULT_VIZ_HOST,
+            "viz_host": project.cluster.viz_host or DEFAULT_VIZ_HOST,
         }

         response = render(request, self.template_name, context)
diff --git a/webapp/apps/conftest.py b/webapp/apps/conftest.py
index 79744528..e395f615 100755
--- a/webapp/apps/conftest.py
+++ b/webapp/apps/conftest.py
@@ -76,6 +76,7 @@ def django_db_setup(django_db_setup, django_db_blocker):
         service_account=comp_api_user.profile,
         url="http://scheduler",
         jwt_secret=cryptkeeper.encrypt(binascii.hexlify(os.urandom(32)).decode()),
+        version="v0",
     )

     common = {
diff --git a/webapp/apps/publish/views.py b/webapp/apps/publish/views.py
index d098a0ff..c0be75a0 100644
--- a/webapp/apps/publish/views.py
+++ b/webapp/apps/publish/views.py
@@ -288,6 +288,7 @@ def post(self, request, *args, **kwargs):
             tag, _ = Tag.objects.get_or_create(
                 project=project,
                 image_tag=data.get("staging_tag"),
+                version=data.get("version"),
                 defaults=dict(cpu=project.cpu, memory=project.memory),
             )
             project.staging_tag = tag
@@ -298,6 +299,7 @@ def post(self, request, *args, **kwargs):
             tag, _ = Tag.objects.get_or_create(
                 project=project,
                 image_tag=data.get("latest_tag"),
+                version=data.get("version"),
                 defaults=dict(cpu=project.cpu, memory=project.memory),
             )
             project.latest_tag = tag
diff --git a/webapp/apps/users/api.py b/webapp/apps/users/api.py
index f98a21f4..1b15c8b9 100644
--- a/webapp/apps/users/api.py
+++ b/webapp/apps/users/api.py
@@ -10,10 +10,7 @@
     TokenAuthentication,
 )

-from oauth2_provider.contrib.rest_framework import (
-    OAuth2Authentication,
-    TokenHasReadWriteScope,
-)
+from oauth2_provider.contrib.rest_framework import OAuth2Authentication

 from webapp.apps.publish.views import GetProjectMixin
 from .permissions import StrictRequiresActive
diff --git a/webapp/apps/users/auth.py b/webapp/apps/users/auth.py
index e303b1ac..997a0dc1 100644
--- a/webapp/apps/users/auth.py
+++ b/webapp/apps/users/auth.py
@@ -4,6 +4,9 @@
 from rest_framework import authentication
 from rest_framework.exceptions import AuthenticationFailed

+from oauth2_provider.contrib.rest_framework import (
+    OAuth2Authentication as BaseOAuth2Authentication,
+)

 from webapp.apps.users.models import (
@@ -52,3 +55,33 @@ def authenticate(self, request):
             raise AuthenticationFailed("No such user")

         return (cluster.service_account.user, None)
+
+
+class ClientOAuth2Authentication(BaseOAuth2Authentication):
+    """
+    Authenticator that forces request.user to be present even if the
+    oauth2_provider package doesn't want it to be.
+
+    Works around the change introduced in:
+    https://github.com/evonove/django-oauth-toolkit/commit/628f9e6ba98007d2940bb1a4c28136c03d81c245
+
+    Reference:
+    https://github.com/evonove/django-oauth-toolkit/issues/38
+
+    """
+
+    def authenticate(self, request):
+        super_result = super().authenticate(request)
+
+        if super_result:
+            # The request was found to be authentic.
+            user, token = super_result
+            if (
+                user is None
+                and token.application.authorization_grant_type == "client-credentials"
+            ):
+                user = token.application.user
+            result = user, token
+        else:
+            result = super_result
+        return result
diff --git a/webapp/apps/users/migrations/0026_auto_20210515_1512.py b/webapp/apps/users/migrations/0026_auto_20210515_1512.py
new file mode 100755
index 00000000..624ab6e6
--- /dev/null
+++ b/webapp/apps/users/migrations/0026_auto_20210515_1512.py
@@ -0,0 +1,44 @@
+# Generated by Django 3.0.14 on 2021-05-15 15:12
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("users", "0025_project_embed_background_color"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="cluster",
+            name="access_token",
+            field=models.CharField(max_length=512, null=True),
+        ),
+        migrations.AddField(
+            model_name="cluster",
+            name="access_token_expires_at",
+            field=models.DateTimeField(null=True),
+        ),
+        migrations.AddField(
+            model_name="cluster",
+            name="cluster_password",
+            field=models.CharField(max_length=512, null=True),
+        ),
+        migrations.AddField(
+            model_name="cluster",
+            name="version",
+            field=models.CharField(default="v0", max_length=32),
+            preserve_default=False,
+        ),
+        migrations.AddField(
+            model_name="cluster",
+            name="viz_host",
+            field=models.CharField(max_length=128, null=True),
+        ),
+        migrations.AddField(
+            model_name="tag",
+            name="version",
+            field=models.CharField(max_length=255, null=True),
+        ),
+    ]
diff --git a/webapp/apps/users/models.py b/webapp/apps/users/models.py
index 782b57d3..b2b6e2f4 100755
--- a/webapp/apps/users/models.py
+++ b/webapp/apps/users/models.py
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from datetime import timedelta
+from datetime import timedelta, datetime
 import json
 import secrets
 import uuid
@@ -182,28 +182,74 @@ def default(self):
         return self.get(service_account__user__username=DEFAULT_CLUSTER_USER)

+class ClusterLoginException(Exception):
+    pass
+
+
 class Cluster(models.Model):
     url = models.URLField(max_length=64)
-    jwt_secret = models.CharField(max_length=512, null=True)
     service_account = models.OneToOneField(
         Profile, null=True, on_delete=models.SET_NULL
     )
     created_at = models.DateTimeField(auto_now_add=True)
     deleted_at = models.DateTimeField(null=True)

+    # v0
+    jwt_secret = models.CharField(max_length=512, null=True)
+
+    # v1
+    cluster_password = models.CharField(max_length=512, null=True)
+    access_token = models.CharField(max_length=512, null=True)
+    access_token_expires_at = models.DateTimeField(null=True)
+
+    # Make viz host configurable to work with multiple clusters at once.
+    viz_host = models.CharField(max_length=128, null=True)
+
+    version = models.CharField(null=False, max_length=32)
+
     objects = ClusterManager()

-    def headers(self):
-        jwt_token = jwt.encode(
-            {"username": self.service_account.user.username,},
-            cryptkeeper.decrypt(self.jwt_secret),
+    def ensure_access_token(self):
+        missing_token = self.access_token is None
+        is_expired = (
+            self.access_token_expires_at is None
+            or self.access_token_expires_at < (timezone.now() + timedelta(seconds=60))
         )
-        return {
-            "Authorization": jwt_token,
-            "Cluster-User": self.service_account.user.username,
-        }
+        print("token is missing", missing_token, "token is expired", is_expired)
+        if missing_token or is_expired:
+            resp = requests.post(
+                f"{self.url}/api/v1/login/access-token",
+                data={
+                    "username": str(self.service_account),
+                    "password": self.cluster_password,
+                },
+            )
+            if resp.status_code != 200:
+                raise ClusterLoginException(
+                    f"Expected 200, got {resp.status_code}: {resp.text}"
+                )
+            data = resp.json()
+            self.access_token = data["access_token"]
+            self.access_token_expires_at = datetime.fromisoformat(data["expires_at"])
+            self.save()
+            self.refresh_from_db()
+
+    def headers(self):
+        if self.version == "v0":
+            jwt_token = jwt.encode(
+                {"username": self.service_account.user.username,},
+                cryptkeeper.decrypt(self.jwt_secret),
+            )
+            return {
+                "Authorization": jwt_token,
+                "Cluster-User": self.service_account.user.username,
+            }
+        elif self.version == "v1":
+            self.ensure_access_token()
+            return {"Authorization": f"Bearer {self.access_token}"}

     def create_user_in_cluster(self, cs_url):
+        # only works for v0.
         resp = requests.post(
             f"{self.url}/auth/",
             json={
@@ -219,6 +265,13 @@ def create_user_in_cluster(self, cs_url):

         raise Exception(f"{resp.status_code} {resp.text}")

+    @property
+    def path_prefix(self):
+        if self.version == "v0":
+            return ""
+        else:
+            return "/api/v1"
+

 class ProjectPermissions:
     READ = (
@@ -481,21 +534,10 @@ def version(self):
             return None
         if self.status != "running":
             return None
-        try:
-            success, result = SyncCompute().submit_job(
-                project=self, task_name=actions.VERSION, task_kwargs=dict()
-            )
-            if success:
-                return result["version"]
-            else:
-                print(f"error retrieving version for {self}", result)
-                return None
-        except Exception as e:
-            print(f"error retrieving version for {self}", e)
-            import traceback
-
-            traceback.print_exc()
-            return None
+        if self.latest_tag:
+            return self.latest_tag.version
+        if self.staging_tag:
+            return self.staging_tag.version

     def is_owner(self, user):
         return user == self.owner.user
@@ -641,6 +683,7 @@ class Tag(models.Model):
     cpu = models.DecimalField(max_digits=5, decimal_places=1, null=True, default=2)
     memory = models.DecimalField(max_digits=5, decimal_places=1, null=True, default=6)
     created_at = models.DateTimeField(auto_now_add=True)
+    version = models.CharField(max_length=255, null=True)

     def __str__(self):
         return str(self.image_tag)
@@ -763,13 +806,14 @@ def create_deployment(self):
             self.tag = self.project.latest_tag
             self.save()

+        cluster: Cluster = self.project.cluster
         resp = requests.post(
-            f"{self.project.cluster.url}/deployments/{self.project}/",
+            f"{cluster.url}{cluster.path_prefix}/deployments/{self.project}/",
             json={"deployment_name": self.public_name, "tag": str(self.tag)},
-            headers=self.project.cluster.headers(),
+            headers=cluster.headers(),
         )

-        if resp.status_code == 200:
+        if resp.status_code in (200, 201):
             return resp.json()
         elif resp.status_code == 400:
             data = resp.json()
@@ -779,17 +823,19 @@ def create_deployment(self):
         raise Exception(f"{resp.status_code} {resp.text}")
Exception(f"{resp.status_code} {resp.text}") def get_deployment(self): + cluster: Cluster = self.project.cluster resp = requests.get( - f"{self.project.cluster.url}/deployments/{self.project}/{self.public_name}/", - headers=self.project.cluster.headers(), + f"{cluster.url}{cluster.path_prefix}/deployments/{self.project}/{self.public_name}/", + headers=cluster.headers(), ) assert resp.status_code == 200, f"Got {resp.status_code}, {resp.text}" return resp.json() def delete_deployment(self): + cluster: Cluster = self.project.cluster resp = requests.delete( - f"{self.project.cluster.url}/deployments/{self.project}/{self.public_name}/", - headers=self.project.cluster.headers(), + f"{cluster.url}{cluster.path_prefix}/deployments/{self.project}/{self.public_name}/", + headers=cluster.headers(), ) assert resp.status_code == 200, f"Got {resp.status_code}, {resp.text}" self.deleted_at = timezone.now() diff --git a/webapp/apps/users/serializers.py b/webapp/apps/users/serializers.py index 75548388..b68d1e26 100644 --- a/webapp/apps/users/serializers.py +++ b/webapp/apps/users/serializers.py @@ -158,6 +158,7 @@ class Meta: class TagUpdateSerializer(serializers.Serializer): latest_tag = serializers.CharField(allow_null=True, required=False) staging_tag = serializers.CharField(allow_null=True, required=False) + version = serializers.CharField(allow_null=True, required=False) def validate(self, attrs): if attrs.get("latest_tag") is None and attrs.get("staging_tag") is None: @@ -169,6 +170,7 @@ def validate(self, attrs): class TagSerializer(serializers.ModelSerializer): project = serializers.StringRelatedField() + version = serializers.CharField(allow_null=True, required=False) class Meta: model = Tag @@ -178,6 +180,7 @@ class Meta: "memory", "cpu", "created_at", + "version", ) read_only = ( diff --git a/webapp/urls.py b/webapp/urls.py index 58a99f38..c263b9b5 100644 --- a/webapp/urls.py +++ b/webapp/urls.py @@ -42,6 +42,11 @@ ), path("outputs/api/", compviews.OutputsAPIView.as_view(), name="outputs_api"), path("inputs/api/", compviews.MyInputsAPIView.as_view(), name="myinputs_api"), + path( + "model-config/api/", + compviews.ModelConfigAPIView.as_view(), + name="modelconfig_api", + ), url(r"^rest-auth/", include("rest_auth.urls")), url(r"^rest-auth/registration/", include("rest_auth.registration.urls")), path("api/v1/sims", compviews.UserSimsAPIView.as_view(), name="sim_api"), diff --git a/workers/cs-workers/.helmignore b/workers/cs-workers/.helmignore new file mode 100644 index 00000000..0e8a0eb3 --- /dev/null +++ b/workers/cs-workers/.helmignore @@ -0,0 +1,23 @@ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. +.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ diff --git a/workers/cs-workers/Chart.yaml b/workers/cs-workers/Chart.yaml new file mode 100644 index 00000000..d570356a --- /dev/null +++ b/workers/cs-workers/Chart.yaml @@ -0,0 +1,27 @@ +apiVersion: v2 +name: cs-workers +description: A Helm chart for Kubernetes + +# A chart can be either an 'application' or a 'library' chart. +# +# Application charts are a collection of templates that can be packaged into versioned archives +# to be deployed. +# +# Library charts provide useful utilities or functions for the chart developer. 
+# a dependency of application charts to inject those utilities and functions into the rendering
+# pipeline. Library charts do not define any templates and therefore cannot be deployed.
+type: application
+
+# This is the chart version. This version number should be incremented each time you make changes
+# to the chart and its templates, including the app version.
+# Versions are expected to follow Semantic Versioning (https://semver.org/)
+version: 0.1.0
+
+# This is the version number of the application being deployed. This version number should be
+# incremented each time you make changes to the application. Versions are not expected to
+# follow Semantic Versioning. They should reflect the version the application is using.
+appVersion: 1.16.0
+# dependencies:
+#   - name: traefik
+#     version: "9.1.1"
+#     repository: "https://helm.traefik.io/traefik"
diff --git a/workers/cs-workers/templates/_helpers.tpl b/workers/cs-workers/templates/_helpers.tpl
new file mode 100644
index 00000000..e21e6717
--- /dev/null
+++ b/workers/cs-workers/templates/_helpers.tpl
@@ -0,0 +1,63 @@
+{{/* vim: set filetype=mustache: */}}
+{{/*
+Expand the name of the chart.
+*/}}
+{{- define "cs-workers.name" -}}
+{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
+{{- end }}
+
+{{/*
+Create a default fully qualified app name.
+We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
+If release name contains chart name it will be used as a full name.
+*/}}
+{{- define "cs-workers.fullname" -}}
+{{- if .Values.fullnameOverride }}
+{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
+{{- else }}
+{{- $name := default .Chart.Name .Values.nameOverride }}
+{{- if contains $name .Release.Name }}
+{{- .Release.Name | trunc 63 | trimSuffix "-" }}
+{{- else }}
+{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
+{{- end }}
+{{- end }}
+{{- end }}
+
+{{/*
+Create chart name and version as used by the chart label.
+*/}}
+{{- define "cs-workers.chart" -}}
+{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
+{{- end }}
+
+{{/*
+Common labels
+*/}}
+{{- define "cs-workers.labels" -}}
+helm.sh/chart: {{ include "cs-workers.chart" . }}
+{{ include "cs-workers.selectorLabels" . }}
+{{- if .Chart.AppVersion }}
+app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
+{{- end }}
+app.kubernetes.io/managed-by: {{ .Release.Service }}
+{{- end }}
+
+{{/*
+Selector labels
+*/}}
+{{- define "cs-workers.selectorLabels" -}}
+app.kubernetes.io/name: {{ include "cs-workers.name" . }}
+app.kubernetes.io/instance: {{ .Release.Name }}
+{{- end }}
+
+{{/*
+Create the name of the service account to use
+*/}}
+{{- define "cs-workers.serviceAccountName" -}}
+{{- if .Values.serviceAccount.create }}
+{{- default (include "cs-workers.fullname" .) .Values.serviceAccount.name }}
+{{- else }}
+{{- default "default" .Values.serviceAccount.name }}
+{{- end }}
+{{- end }}
diff --git a/workers/cs-workers/templates/api-Deployment.yaml b/workers/cs-workers/templates/api-Deployment.yaml
new file mode 100755
index 00000000..8129aa27
--- /dev/null
+++ b/workers/cs-workers/templates/api-Deployment.yaml
@@ -0,0 +1,80 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: api
+  namespace: {{ .Values.workers_namespace }}
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: api
+  template:
+    metadata:
+      labels:
+        app: api
+    spec:
+      serviceAccountName: workers-api
+      containers:
+        - name: api
+          image: "{{ .Values.registry }}/{{ .Values.project }}/workers_api:{{ .Values.tag }}"
+          ports:
+            - containerPort: 5000
+          env:
+            - name: BUCKET
+              value: "{{ .Values.bucket }}"
+            - name: PROJECT
+              value: "{{ .Values.project }}"
+            {{ if .Values.workers_api_host }}
+            - name: WORKERS_API_HOST
+              value: "{{ .Values.workers_api_host }}"
+            {{ end }}
+            - name: VIZ_HOST
+              value: "{{ .Values.viz_host }}"
+            - name: API_SECRET_KEY
+              valueFrom:
+                secretKeyRef:
+                  name: api-secret
+                  key: API_SECRET_KEY
+            - name: BACKEND_CORS_ORIGINS
+              value: '{{ .Values.api.allow_origins | toJson }}'
+            - name: PROJECT_NAMESPACE
+              value: '{{ .Values.project_namespace }}'
+            - name: DB_USER
+              valueFrom:
+                secretKeyRef:
+                  name: workers-db-secret
+                  key: USER
+            - name: DB_PASS
+              valueFrom:
+                secretKeyRef:
+                  name: workers-db-secret
+                  key: PASSWORD
+            - name: DB_NAME
+              valueFrom:
+                secretKeyRef:
+                  name: workers-db-secret
+                  key: NAME
+            - name: DB_HOST
+              valueFrom:
+                secretKeyRef:
+                  name: workers-db-secret
+                  key: HOST
+          resources:
+            requests:
+              cpu: 1
+              memory: 1G
+            limits:
+              cpu: 1
+              memory: 2G
+
+        {{ if .Values.db.use_gcp_cloud_proxy }}
+        - name: cloud-sql-proxy
+          image: gcr.io/cloudsql-docker/gce-proxy:1.17
+          command:
+            - "/cloud_sql_proxy"
+            - "-instances={{ .Values.db.gcp_sql_instance_name }}=tcp:5432"
+          securityContext:
+            runAsNonRoot: true
+        {{ end }}
+      nodeSelector:
+        component: api
diff --git a/workers/cs_workers/templates/services/scheduler-RBAC.yaml b/workers/cs-workers/templates/api-RBAC.yaml
similarity index 71%
rename from workers/cs_workers/templates/services/scheduler-RBAC.yaml
rename to workers/cs-workers/templates/api-RBAC.yaml
index 35ecae7d..b912d240 100644
--- a/workers/cs_workers/templates/services/scheduler-RBAC.yaml
+++ b/workers/cs-workers/templates/api-RBAC.yaml
@@ -1,13 +1,14 @@
 apiVersion: v1
 kind: ServiceAccount
 metadata:
-  name: scheduler
+  name: workers-api
+  namespace: {{ .Values.workers_namespace }}
 ---
 apiVersion: rbac.authorization.k8s.io/v1
 kind: Role
 metadata:
   name: job-admin
-  namespace: default
+  namespace: {{ .Values.project_namespace }}
 rules:
   - apiGroups: ["batch", "extensions"]
     resources: ["jobs"]
@@ -17,11 +18,11 @@ apiVersion: rbac.authorization.k8s.io/v1
 kind: RoleBinding
 metadata:
   name: job-admin
-  namespace: default
+  namespace: {{ .Values.project_namespace }}
 subjects:
   - kind: ServiceAccount
-    name: scheduler
-    namespace: default
+    name: workers-api
+    namespace: {{ .Values.workers_namespace }}
 roleRef:
   kind: Role
   name: job-admin
@@ -31,7 +32,7 @@ apiVersion: rbac.authorization.k8s.io/v1
 kind: Role
 metadata:
   name: viz-admin
-  namespace: default
+  namespace: {{ .Values.project_namespace }}
 rules:
   - apiGroups: ["apps", "", "traefik.containo.us"]
     resources: ["deployments", "services", "ingressroutes"]
@@ -41,11 +42,11 @@ apiVersion: rbac.authorization.k8s.io/v1
 kind: RoleBinding
 metadata:
   name: viz-admin
-  namespace: default
+  namespace: {{ .Values.project_namespace }}
 subjects:
   - kind: ServiceAccount
-    name: scheduler
-    namespace: default
+    name: workers-api
+    namespace: {{ .Values.workers_namespace }}
 roleRef:
   kind: Role
   name: viz-admin
diff --git a/workers/cs-workers/templates/api-Service.yaml b/workers/cs-workers/templates/api-Service.yaml
new file mode 100644
index 00000000..bd2faf4e
--- /dev/null
+++ b/workers/cs-workers/templates/api-Service.yaml
@@ -0,0 +1,12 @@
+apiVersion: v1
+kind: Service
+metadata:
+  name: api
+  namespace: {{ .Values.workers_namespace }}
+spec:
+  ports:
+    - port: 80
+      targetPort: 5000
+  selector:
+    app: api
+  type: LoadBalancer
diff --git a/workers/cs-workers/templates/api-ingressroute.yaml b/workers/cs-workers/templates/api-ingressroute.yaml
new file mode 100644
index 00000000..187dcf41
--- /dev/null
+++ b/workers/cs-workers/templates/api-ingressroute.yaml
@@ -0,0 +1,35 @@
+{{ if .Values.workers_api_host }}
+apiVersion: traefik.containo.us/v1alpha1
+kind: IngressRoute
+metadata:
+  name: api-tls
+  namespace: {{ .Values.workers_namespace }}
+spec:
+  entryPoints:
+    - websecure
+  routes:
+    - match: Host(`{{ .Values.workers_api_host }}`)
+      kind: Rule
+      services:
+        - name: api
+          port: 80
+  tls:
+    certResolver: myresolver
+---
+apiVersion: traefik.containo.us/v1alpha1
+kind: IngressRoute
+metadata:
+  name: api
+  namespace: {{ .Values.workers_namespace }}
+spec:
+  entryPoints:
+    - web
+  routes:
+    - match: Host(`{{ .Values.workers_api_host }}`)
+      kind: Rule
+      services:
+        - name: api
+          port: 80
+  tls:
+    certResolver: myresolver
+{{ end }}
\ No newline at end of file
diff --git a/workers/cs-workers/templates/db-deployment.yaml b/workers/cs-workers/templates/db-deployment.yaml
new file mode 100644
index 00000000..5de8196f
--- /dev/null
+++ b/workers/cs-workers/templates/db-deployment.yaml
@@ -0,0 +1,57 @@
+{{ if .Values.db.deploy_db }}
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: workers-db
+  namespace: {{ .Values.workers_namespace }}
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: workers-db
+  template:
+    metadata:
+      labels:
+        app: workers-db
+    spec:
+      containers:
+        - name: workers-db
+          env:
+            - name: POSTGRES_USER
+              valueFrom:
+                secretKeyRef:
+                  name: workers-db-secret
+                  key: USER
+            - name: POSTGRES_PASSWORD
+              valueFrom:
+                secretKeyRef:
+                  name: workers-db-secret
+                  key: PASSWORD
+            - name: POSTGRES_DB
+              valueFrom:
+                secretKeyRef:
+                  name: workers-db-secret
+                  key: NAME
+          image: postgres:12.4
+          ports:
+            - containerPort: 5432
+          resources:
+            requests:
+              cpu: 100m
+              memory: 100Mi
+          volumeMounts:
+            {{- range $name, $value := .Values.db.volumeMounts }}
+            - name: {{ $value.name }}
+              mountPath: {{ quote $value.mountPath }}
+              subPath: {{ quote $value.subPath }}
+            {{- end }}
+      volumes:
+        {{- range $name, $value := .Values.db.volumes }}
+        - name: {{ $value.name }}
+          hostPath:
+            path: {{ $value.hostPath.path }}
+            type: {{ $value.hostPath.type }}
+        {{- end}}
+      nodeSelector:
+        component: web
+{{ end }}
\ No newline at end of file
diff --git a/workers/cs-workers/templates/db-secret.yaml b/workers/cs-workers/templates/db-secret.yaml
new file mode 100644
index 00000000..8361a878
--- /dev/null
+++ b/workers/cs-workers/templates/db-secret.yaml
@@ -0,0 +1,11 @@
+apiVersion: v1
+kind: Secret
+metadata:
+  name: workers-db-secret
+  namespace: {{ .Values.workers_namespace }}
+type: Opaque
+stringData:
+  USER: {{ .Values.db.user }}
+  PASSWORD: {{ .Values.db.password }}
+  HOST: {{ .Values.db.host }}
+  NAME: {{ .Values.db.name }}
\ No newline at end of file
diff --git a/workers/cs-workers/templates/db-service.yaml b/workers/cs-workers/templates/db-service.yaml
new file mode 100644
index 00000000..6e8cf665
--- /dev/null
+++ b/workers/cs-workers/templates/db-service.yaml
@@ -0,0 +1,15 @@
+{{ if .Values.db.deploy_db }}
+apiVersion: v1
+kind: Service
+metadata:
+  labels:
+    app: workers-db
+  name: workers-db
+  namespace: {{ .Values.workers_namespace }}
+spec:
+  ports:
+    - port: 5432
+      targetPort: 5432
+  selector:
+    app: workers-db
+{{ end }}
\ No newline at end of file
diff --git a/workers/cs_workers/templates/services/job-cleanup-Job.yaml b/workers/cs-workers/templates/job-cleanup-Job.yaml
similarity index 66%
rename from workers/cs_workers/templates/services/job-cleanup-Job.yaml
rename to workers/cs-workers/templates/job-cleanup-Job.yaml
index 72f9325a..5bde05b9 100644
--- a/workers/cs_workers/templates/services/job-cleanup-Job.yaml
+++ b/workers/cs-workers/templates/job-cleanup-Job.yaml
@@ -2,6 +2,7 @@ apiVersion: batch/v1beta1
 kind: CronJob
 metadata:
   name: job-cleanup
+  namespace: {{ .Values.workers_namespace }}
 spec:
   schedule: "*/30 * * * *"
   successfulJobsHistoryLimit: 0
@@ -13,5 +14,5 @@ spec:
           containers:
             - name: kubectl-container
               image: bitnami/kubectl:latest
-              command: ["sh", "-c", "kubectl delete jobs --field-selector status.successful=1"]
+              command: ["sh", "-c", "kubectl delete jobs --namespace {{ .Values.project_namespace }} --field-selector status.successful=1"]
           restartPolicy: Never
diff --git a/workers/cs_workers/templates/services/job-cleanup-RBAC.yaml b/workers/cs-workers/templates/job-cleanup-RBAC.yaml
similarity index 73%
rename from workers/cs_workers/templates/services/job-cleanup-RBAC.yaml
rename to workers/cs-workers/templates/job-cleanup-RBAC.yaml
index 61ddb5e2..a4336ba4 100644
--- a/workers/cs_workers/templates/services/job-cleanup-RBAC.yaml
+++ b/workers/cs-workers/templates/job-cleanup-RBAC.yaml
@@ -2,12 +2,13 @@ apiVersion: v1
 kind: ServiceAccount
 metadata:
   name: job-cleanup
+  namespace: {{ .Values.workers_namespace }}
 ---
 apiVersion: rbac.authorization.k8s.io/v1
 kind: Role
 metadata:
   name: job-remove
-  namespace: default
+  namespace: {{ .Values.project_namespace }}
 rules:
   - apiGroups: ["batch", "extensions"]
     resources: ["jobs"]
@@ -17,11 +18,11 @@ apiVersion: rbac.authorization.k8s.io/v1
 kind: RoleBinding
 metadata:
   name: job-remove
-  namespace: default
+  namespace: {{ .Values.project_namespace }}
 subjects:
   - kind: ServiceAccount
     name: job-cleanup
-    namespace: default
+    namespace: {{ .Values.workers_namespace }}
 roleRef:
   kind: Role
   name: job-remove
diff --git a/workers/cs-workers/templates/outputs-processor-Deployment.yaml b/workers/cs-workers/templates/outputs-processor-Deployment.yaml
new file mode 100755
index 00000000..f7baf308
--- /dev/null
+++ b/workers/cs-workers/templates/outputs-processor-Deployment.yaml
@@ -0,0 +1,44 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: outputs-processor
+  namespace: {{ .Values.workers_namespace }}
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: outputs-processor
+  template:
+    metadata:
+      labels:
+        app: outputs-processor
+    spec:
+      serviceAccountName: outputs-processor
+      containers:
+        - name: outputs-processor
+          image: "{{ .Values.registry }}/{{ .Values.project }}/outputs_processor:{{ .Values.tag }}"
+          ports:
+            - containerPort: 5000
+          env:
+            - name: BUCKET
+              value: {{ .Values.bucket }}
+            - name: PROJECT
+              value: {{ .Values.project }}
+            - name: REDIS_HOST
+              value: {{ .Values.redis.host }}
+            - name: REDIS_PORT
+              value: "{{ .Values.redis.port }}"
+            - name: REDIS_PASSWORD
valueFrom: + secretKeyRef: + name: workers-redis-secret + key: PASSWORD + resources: + requests: + cpu: 1 + memory: 1G + limits: + cpu: 1 + memory: 2G + nodeSelector: + component: api diff --git a/workers/cs_workers/templates/services/outputs-processor-ServiceAccount.yaml b/workers/cs-workers/templates/outputs-processor-RBAC.yaml similarity index 61% rename from workers/cs_workers/templates/services/outputs-processor-ServiceAccount.yaml rename to workers/cs-workers/templates/outputs-processor-RBAC.yaml index b78a435d..dd3f5e0d 100644 --- a/workers/cs_workers/templates/services/outputs-processor-ServiceAccount.yaml +++ b/workers/cs-workers/templates/outputs-processor-RBAC.yaml @@ -2,3 +2,4 @@ apiVersion: v1 kind: ServiceAccount metadata: name: outputs-processor + namespace: {{ .Values.workers_namespace }} diff --git a/workers/cs_workers/templates/services/outputs-processor-Service.yaml b/workers/cs-workers/templates/outputs-processor-Service.yaml similarity index 58% rename from workers/cs_workers/templates/services/outputs-processor-Service.yaml rename to workers/cs-workers/templates/outputs-processor-Service.yaml index 2edac636..4ac71616 100644 --- a/workers/cs_workers/templates/services/outputs-processor-Service.yaml +++ b/workers/cs-workers/templates/outputs-processor-Service.yaml @@ -2,9 +2,10 @@ apiVersion: v1 kind: Service metadata: name: outputs-processor + namespace: {{ .Values.workers_namespace }} spec: ports: - - port: 80 - targetPort: 8888 + - port: 80 + targetPort: 5000 selector: app: outputs-processor diff --git a/workers/cs-workers/templates/project_namespace.yaml b/workers/cs-workers/templates/project_namespace.yaml new file mode 100644 index 00000000..bb37da36 --- /dev/null +++ b/workers/cs-workers/templates/project_namespace.yaml @@ -0,0 +1,4 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: {{ .Values.project_namespace }} \ No newline at end of file diff --git a/workers/cs-workers/templates/redis-master-Deployment.yaml b/workers/cs-workers/templates/redis-master-Deployment.yaml new file mode 100644 index 00000000..bd511e2b --- /dev/null +++ b/workers/cs-workers/templates/redis-master-Deployment.yaml @@ -0,0 +1,53 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app: redis + name: redis-master + namespace: {{ .Values.workers_namespace }} +spec: + replicas: 1 + selector: + matchLabels: + app: redis + role: master + tier: backend + template: + metadata: + labels: + app: redis + role: master + tier: backend + spec: + containers: + - env: [] + command: ["redis-server", "--appendonly", "yes"] + image: redis:6.2.1 + name: master + ports: + - containerPort: 6379 + resources: + requests: + cpu: 100m + memory: 100Mi + volumeMounts: + {{- range $name, $value := .Values.redis.volumeMounts }} + - name: {{ $value.name }} + mountPath: {{ quote $value.mountPath }} + {{- end }} + volumes: + {{- range $name, $value := .Values.redis.volumes }} + - name: {{ $value.name }} + {{ if $value.hostPath }} + hostPath: + path: {{ $value.hostPath.path }} + type: {{ $value.hostPath.type }} + {{ end }} + {{ if $value.gcePersistentDisk }} + gcePersistentDisk: + pdName: {{ $value.gcePersistentDisk.pdName }} + fsType: {{ $value.gcePersistentDisk.fsType }} + {{ end }} + {{- end}} + nodeSelector: + component: api diff --git a/workers/cs_workers/templates/services/redis-master-Service.yaml b/workers/cs-workers/templates/redis-master-Service.yaml similarity index 82% rename from workers/cs_workers/templates/services/redis-master-Service.yaml rename to 
workers/cs-workers/templates/redis-master-Service.yaml index 04af2120..e5e838a4 100644 --- a/workers/cs_workers/templates/services/redis-master-Service.yaml +++ b/workers/cs-workers/templates/redis-master-Service.yaml @@ -6,6 +6,7 @@ metadata: role: master tier: backend name: redis-master + namespace: {{ .Values.workers_namespace }} spec: ports: - port: 6379 diff --git a/workers/cs-workers/templates/redis-secret.yaml b/workers/cs-workers/templates/redis-secret.yaml new file mode 100644 index 00000000..3cbf24a3 --- /dev/null +++ b/workers/cs-workers/templates/redis-secret.yaml @@ -0,0 +1,8 @@ +apiVersion: v1 +kind: Secret +metadata: + name: workers-redis-secret + namespace: {{ .Values.workers_namespace }} +type: Opaque +stringData: + PASSWORD: {{ .Values.redis.password }} diff --git a/workers/cs-workers/templates/rq-Deployment.yaml b/workers/cs-workers/templates/rq-Deployment.yaml new file mode 100644 index 00000000..7dd948a7 --- /dev/null +++ b/workers/cs-workers/templates/rq-Deployment.yaml @@ -0,0 +1,44 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: rq-worker-outputs + namespace: {{ .Values.workers_namespace }} +spec: + replicas: 1 + selector: + matchLabels: + app: rq-worker-outputs + template: + metadata: + labels: + app: rq-worker-outputs + spec: + serviceAccountName: rq-worker + containers: + - name: rq-worker-outputs + command: + ["rq", "worker", "--with-scheduler", "-c", "cs_workers.services.rq_settings"] + image: "{{ .Values.registry }}/{{ .Values.project }}/outputs_processor:{{ .Values.tag }}" + env: + - name: BUCKET + value: {{ .Values.bucket }} + - name: PROJECT + value: {{ .Values.project }} + - name: REDIS_HOST + value: {{ .Values.redis.host }} + - name: REDIS_PORT + value: "{{ .Values.redis.port }}" + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: workers-redis-secret + key: PASSWORD + resources: + requests: + cpu: 1 + memory: 1G + limits: + cpu: 1 + memory: 2G + nodeSelector: + component: api diff --git a/workers/cs-workers/templates/rq-RBAC.yaml b/workers/cs-workers/templates/rq-RBAC.yaml new file mode 100644 index 00000000..da48fd78 --- /dev/null +++ b/workers/cs-workers/templates/rq-RBAC.yaml @@ -0,0 +1,5 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: rq-worker + namespace: {{ .Values.workers_namespace }} diff --git a/workers/cs-workers/templates/workers-api-secret.yaml b/workers/cs-workers/templates/workers-api-secret.yaml new file mode 100644 index 00000000..0e126fdc --- /dev/null +++ b/workers/cs-workers/templates/workers-api-secret.yaml @@ -0,0 +1,8 @@ +apiVersion: v1 +kind: Secret +metadata: + name: api-secret + namespace: {{ .Values.workers_namespace }} +type: Opaque +stringData: + API_SECRET_KEY: {{ .Values.api.secret_key }} diff --git a/workers/cs-workers/templates/workers_namespace.yaml b/workers/cs-workers/templates/workers_namespace.yaml new file mode 100644 index 00000000..7219c80b --- /dev/null +++ b/workers/cs-workers/templates/workers_namespace.yaml @@ -0,0 +1,4 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: {{ .Values.workers_namespace }} \ No newline at end of file diff --git a/workers/cs-workers/values.yaml b/workers/cs-workers/values.yaml new file mode 100644 index 00000000..3d3d259a --- /dev/null +++ b/workers/cs-workers/values.yaml @@ -0,0 +1,62 @@ +# Default values for cs-workers. +# This is a YAML-formatted file. +# Declare variables to be passed into your templates. 
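The rq-worker-outputs Deployment above starts its worker with "rq worker --with-scheduler -c cs_workers.services.rq_settings" and injects REDIS_HOST, REDIS_PORT, and REDIS_PASSWORD from workers-redis-secret. The settings module itself is not part of this diff, so the following is only a minimal sketch of what it plausibly contains, using rq's standard module-level settings names (REDIS_URL, QUEUES); everything here is an assumption, not the module's actual contents:

import os

# Hypothetical contents of cs_workers/services/rq_settings.py; names and
# defaults below are assumptions, since the module is not shown in this diff.
REDIS_HOST = os.environ.get("REDIS_HOST", "redis-master")
REDIS_PORT = os.environ.get("REDIS_PORT", "6379")
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or None

# `rq worker -c <module>` reads module-level settings such as REDIS_URL and QUEUES.
if REDIS_PASSWORD:
    REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0"
else:
    REDIS_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/0"

QUEUES = ["default"]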
+ +replicaCount: 1 +bucket: cs-outputs-dev-private + +viz_host: devviz.compute.studio +# image: +project: project +registry: gcr.io +# Overrides the image tag whose default is the chart appVersion. +tag: tag + +workers_namespace: workers +project_namespace: projects + +api: + secret_key: abc123 + allow_origins: + - "http://10.0.0.137:5000" + - "http://localhost:5000" + - "http://api.workers.svc.cluster.local" +redis: + host: "redis-master" + port: "6379" + password: "" + volumes: + - name: redis-volume-v1 + hostPath: + path: /redis-queue-data + type: Directory + volumeMounts: + - mountPath: /data + name: redis-volume-v1 + +db: + name: cluster_db + user: postgres + password: password + host: workers-db + + deploy_db: true + # use_gcp_cloud_proxy: false + # gcp_sql_instance_name: null + + volumes: + - name: workers-db-volume + hostPath: + path: /workers-db-data + type: DirectoryOrCreate + + volumeMounts: + - mountPath: /var/lib/postgresql/data/ + name: workers-db-volume + subPath: postgres +# credsVolumes: +# env_var: GOOGLE_APPLICATION_CREDENTIALS +# path: /google-creds.json +# volumes: +# - name: google-creds-volume +# hostPath diff --git a/workers/cs_workers/cli.py b/workers/cs_workers/cli.py index 41d0fbc5..9476e2af 100644 --- a/workers/cs_workers/cli.py +++ b/workers/cs_workers/cli.py @@ -8,11 +8,13 @@ from cs_deploy.config import workers_config as config import cs_workers.services.manage -import cs_workers.services.scheduler + +# import cs_workers.services.scheduler import cs_workers.services.outputs_processor import cs_workers.models.manage import cs_workers.models.executors.job -import cs_workers.models.executors.api_task + +# import cs_workers.models.executors.api_task import cs_workers.models.executors.server @@ -31,11 +33,11 @@ def cli(subparsers: argparse._SubParsersAction = None): sub_parsers = parser.add_subparsers() cs_workers.services.manage.cli(sub_parsers, config=config) - cs_workers.services.scheduler.cli(sub_parsers) - cs_workers.services.outputs_processor.cli(sub_parsers) + # cs_workers.services.scheduler.cli(sub_parsers) + # cs_workers.services.outputs_processor.cli(sub_parsers) cs_workers.models.manage.cli(sub_parsers) cs_workers.models.executors.job.cli(sub_parsers) - cs_workers.models.executors.api_task.cli(sub_parsers) + # cs_workers.models.executors.api_task.cli(sub_parsers) cs_workers.models.executors.server.cli(sub_parsers) if subparsers is None: diff --git a/workers/cs_workers/dockerfiles/Dockerfile.outputs_processor b/workers/cs_workers/dockerfiles/Dockerfile.outputs_processor index 862e1719..c2a625f1 100755 --- a/workers/cs_workers/dockerfiles/Dockerfile.outputs_processor +++ b/workers/cs_workers/dockerfiles/Dockerfile.outputs_processor @@ -12,8 +12,8 @@ RUN apt-get update && \ RUN pip install -r requirements.txt && \ - pip install pyppeteer2 && \ - conda install -c conda-forge jinja2 bokeh tornado dask && \ + pip install pyppeteer2 rq && \ + conda install -c conda-forge jinja2 bokeh && \ pyppeteer-install RUN mkdir /home/cs_workers @@ -27,8 +27,8 @@ RUN pip install -e ./secrets COPY deploy /home/deploy RUN pip install -e ./deploy -WORKDIR /home +WORKDIR /home/cs_workers ENV PYTHONUNBUFFERED 1 -CMD ["csw", "outputs", "--start"] \ No newline at end of file +CMD ["uvicorn", "services.outputs_processor:app", "--host", "0.0.0.0", "--port", "5000", "--reload"] \ No newline at end of file diff --git a/workers/cs_workers/dockerfiles/Dockerfile.scheduler b/workers/cs_workers/dockerfiles/Dockerfile.workers_api similarity index 67% rename from 
workers/cs_workers/dockerfiles/Dockerfile.scheduler rename to workers/cs_workers/dockerfiles/Dockerfile.workers_api index 6d13c983..fc022698 100755 --- a/workers/cs_workers/dockerfiles/Dockerfile.scheduler +++ b/workers/cs_workers/dockerfiles/Dockerfile.workers_api @@ -12,7 +12,7 @@ EXPOSE 80 EXPOSE 8888 RUN pip install -r requirements.txt && \ - conda install tornado + pip install python-multipart sqlalchemy python-jose[cryptography] psycopg2-binary passlib[bcrypt] RUN mkdir /home/cs_workers COPY workers/cs_workers /home/cs_workers @@ -29,4 +29,6 @@ WORKDIR /home ENV PYTHONUNBUFFERED 1 -CMD ["csw", "scheduler", "--start"] \ No newline at end of file +WORKDIR /home/cs_workers/services/ + +CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "5000", "--reload"] \ No newline at end of file diff --git a/workers/cs_workers/models/clients/api_task.py b/workers/cs_workers/models/clients/api_task.py deleted file mode 100644 index 0f1a529f..00000000 --- a/workers/cs_workers/models/clients/api_task.py +++ /dev/null @@ -1,26 +0,0 @@ -import httpx - - -from cs_workers.utils import clean - - -class APITask: - def __init__(self, owner, title, task_id, task_name, **task_kwargs): - self.owner = owner - self.title = title - self.task_id = task_id - self.task_name = task_name - self.task_kwargs = task_kwargs - - async def create(self, asynchronous=False): - method = "async" if asynchronous else "sync" - async with httpx.AsyncClient() as client: - resp = await client.post( - f"http://{clean(self.owner)}-{clean(self.title)}-api-task/{method}/", - json={ - "task_id": self.task_id, - "task_name": self.task_name, - "task_kwargs": self.task_kwargs, - }, - ) - return resp diff --git a/workers/cs_workers/models/clients/job.py b/workers/cs_workers/models/clients/job.py index de36f7c9..87af101c 100644 --- a/workers/cs_workers/models/clients/job.py +++ b/workers/cs_workers/models/clients/job.py @@ -7,6 +7,7 @@ from kubernetes import client as kclient, config as kconfig from cs_workers.utils import clean, redis_conn_from_env +from cs_workers.models.secrets import ModelSecrets redis_conn = dict( username="scheduler", @@ -23,60 +24,60 @@ def __init__( title, tag, model_config, - job_id=None, - job_kwargs=None, + job_id, + callback_url, + route_name="sim", cr="gcr.io", incluster=True, - rclient=None, quiet=True, + namespace="default", ): self.project = project self.owner = owner self.title = title self.tag = tag self.model_config = model_config + print(self.model_config) self.cr = cr self.quiet = quiet + self.namespace = namespace self.incluster = incluster - if rclient is None: - self.rclient = redis.Redis(**redis_conn) - else: - self.rclient = rclient if self.incluster: kconfig.load_incluster_config() else: kconfig.load_kube_config() self.api_client = kclient.BatchV1Api() - self.job = self.configure(owner, title, tag, job_id) - self.save_job_kwargs(self.job_id, job_kwargs) + self.job = self.configure(owner, title, tag, job_id, callback_url, route_name) def env(self, owner, title, config): safeowner = clean(owner) safetitle = clean(title) envs = [ - kclient.V1EnvVar("OWNER", config["owner"]), - kclient.V1EnvVar("TITLE", config["title"]), + kclient.V1EnvVar("OWNER", owner), + kclient.V1EnvVar("TITLE", title), kclient.V1EnvVar("EXP_TASK_TIME", str(config["exp_task_time"])), ] - for sec in [ - "BUCKET", - "REDIS_HOST", - "REDIS_PORT", - "REDIS_EXECUTOR_PW", - ]: - envs.append( - kclient.V1EnvVar( - sec, - value_from=kclient.V1EnvVarSource( - secret_key_ref=( - kclient.V1SecretKeySelector(key=sec, 
name="worker-secret") - ) - ), - ) - ) - - for secret in self.model_config._list_secrets(config): + # for sec in [ + # "BUCKET", + # "REDIS_HOST", + # "REDIS_PORT", + # "REDIS_EXECUTOR_PW", + # ]: + # envs.append( + # kclient.V1EnvVar( + # sec, + # value_from=kclient.V1EnvVarSource( + # secret_key_ref=( + # kclient.V1SecretKeySelector(key=sec, name="worker-secret") + # ) + # ), + # ) + # ) + + for secret in ModelSecrets( + owner=owner, title=title, project=self.project + ).list(): envs.append( kclient.V1EnvVar( name=secret, @@ -91,13 +92,10 @@ def env(self, owner, title, config): ) return envs - def configure(self, owner, title, tag, job_id=None): - if job_id is None: - job_id = str(uuid.uuid4()) - else: - job_id = str(job_id) + def configure(self, owner, title, tag, job_id, callback_url, route_name): + job_id = str(job_id) - config = self.model_config.projects()[f"{owner}/{title}"] + config = self.model_config safeowner = clean(owner) safetitle = clean(title) @@ -105,7 +103,14 @@ def configure(self, owner, title, tag, job_id=None): container = kclient.V1Container( name=job_id, image=f"{self.cr}/{self.project}/{safeowner}_{safetitle}_tasks:{tag}", - command=["csw", "job", "--job-id", job_id, "--route-name", "sim"], + command=[ + "csw", + "job", + "--callback-url", + callback_url, + "--route-name", + route_name, + ], env=self.env(owner, title, config), resources=kclient.V1ResourceRequirements(**config["resources"]), ) @@ -137,18 +142,15 @@ def configure(self, owner, title, tag, job_id=None): return job - def save_job_kwargs(self, job_id, job_kwargs): - if not job_id.startswith("job-"): - job_id = f"job-{job_id}" - self.rclient.set(job_id, json.dumps(job_kwargs)) - def create(self): - return self.api_client.create_namespaced_job(body=self.job, namespace="default") + return self.api_client.create_namespaced_job( + body=self.job, namespace=self.namespace + ) def delete(self): return self.api_client.delete_namespaced_job( name=self.job.metadata.name, - namespace="default", + namespace=self.namespace, body=kclient.V1DeleteOptions(), ) diff --git a/workers/cs_workers/models/clients/server.py b/workers/cs_workers/models/clients/server.py index c1dc33f6..271a0533 100644 --- a/workers/cs_workers/models/clients/server.py +++ b/workers/cs_workers/models/clients/server.py @@ -9,6 +9,7 @@ from cs_workers.utils import clean, redis_conn_from_env from cs_workers.config import ModelConfig from cs_workers.ingressroute import IngressRouteApi, ingressroute_template +from cs_workers.models.secrets import ModelSecrets PORT = 8010 @@ -71,7 +72,9 @@ def env(self, owner, title, deployment_name, config): kclient.V1EnvVar("TITLE", config["title"]), ] - for secret in self.model_config._list_secrets(config): + for secret in ModelSecrets( + owner=owner, title=title, project=self.project + ).list(): envs.append( kclient.V1EnvVar( name=secret, @@ -94,7 +97,7 @@ def env(self, owner, title, deployment_name, config): return envs def configure(self): - config = self.model_config.projects()[f"{self.owner}/{self.title}"] + config = self.model_config safeowner = clean(self.owner) safetitle = clean(self.title) app_name = f"{safeowner}-{safetitle}" @@ -270,14 +273,19 @@ def create(self): deployment_resp = self.deployment_api_client.create_namespaced_deployment( namespace=self.namespace, body=self.deployment ) + print("dep resp") + print(deployment_resp) service_resp = self.service_api_client.create_namespaced_service( namespace=self.namespace, body=self.service ) - + print("svc resp") + print(service_resp) ingressroute_resp = 
self.ir_api_client.create_namespaced_ingressroute( namespace=self.namespace, body=self.ingressroute ) + print("ir resp") + print(ingressroute_resp) return deployment_resp, service_resp, ingressroute_resp @@ -321,21 +329,3 @@ def full_name(self): safeowner = clean(self.owner) safetitle = clean(self.title) return f"{safeowner}-{safetitle}-{self.deployment_name}" - - -if __name__ == "__main__": - server = Server( - project="cs-workers-dev", - owner="hdoupe", - title="ccc-widget", - tag="fix-iframe-link3", - deployment_name="hankdoupe", - model_config=ModelConfig("cs-workers-dev", "https://dev.compute.studio"), - callable_name="dash", - incluster=False, - quiet=True, - ) - server - server.configure() - server.create() - # server.delete() diff --git a/workers/cs_workers/models/executors/job.py b/workers/cs_workers/models/executors/job.py index a0ad740e..6bf6b44c 100644 --- a/workers/cs_workers/models/executors/job.py +++ b/workers/cs_workers/models/executors/job.py @@ -1,50 +1,46 @@ import argparse import asyncio -import functools -import json -import os -import redis import httpx import cs_storage -from cs_workers.models.executors.task_wrapper import async_task_wrapper +from cs_workers.models.executors.task_wrapper import task_wrapper - -def sim_handler(task_id, meta_param_dict, adjustment): +try: from cs_config import functions +except ImportError: + functions = None + + +def version(**task_kwargs): + return {"version": functions.get_version()} + + +def defaults(meta_param_dict=None, **task_kwargs): + return functions.get_inputs(meta_param_dict) + +def parse(meta_param_dict, adjustment, errors_warnings): + return functions.validate_inputs(meta_param_dict, adjustment, errors_warnings) + + +def sim(meta_param_dict, adjustment): outputs = functions.run_model(meta_param_dict, adjustment) print("got result") - outputs = cs_storage.serialize_to_json(outputs) - print("storing results") - for i in range(3): - try: - resp = httpx.post( - "http://outputs-processor/write/", - json={"task_id": task_id, "outputs": outputs}, - timeout=120.0, - ) - break - except Exception as e: - print(i, e) - - print("got resp", resp.status_code, resp.url) - assert resp.status_code == 200, f"Got code: {resp.status_code}" - return resp.json() + return cs_storage.serialize_to_json(outputs) -routes = {"sim": sim_handler} +routes = {"version": version, "defaults": defaults, "parse": parse, "sim": sim} def main(args: argparse.Namespace): asyncio.run( - async_task_wrapper(args.job_id, args.route_name, routes[args.route_name]) + task_wrapper(args.callback_url, args.route_name, routes[args.route_name]) ) def cli(subparsers: argparse._SubParsersAction): parser = subparsers.add_parser("job", description="CLI for C/S jobs.") - parser.add_argument("--job-id", "-t", required=True) - parser.add_argument("--route-name", "-r", required=True) + parser.add_argument("--callback-url", required=True) + parser.add_argument("--route-name", required=True) parser.set_defaults(func=main) diff --git a/workers/cs_workers/models/executors/task_wrapper.py b/workers/cs_workers/models/executors/task_wrapper.py index 8dccae06..32a3adca 100644 --- a/workers/cs_workers/models/executors/task_wrapper.py +++ b/workers/cs_workers/models/executors/task_wrapper.py @@ -1,72 +1,60 @@ -import functools -import json import os -import re import time import traceback -import redis import httpx -import cs_storage - -from cs_workers.utils import redis_conn_from_env - - -redis_conn = dict( - username="executor", - password=os.environ.get("REDIS_EXECUTOR_PW"), - 
**redis_conn_from_env(), -) try: from cs_config import functions except ImportError as ie: - # if os.environ.get("IS_FLASK", "False") == "True": - # functions = None - # else: - # raise ie pass -async def sync_task_wrapper(task_id, task_name, func, task_kwargs=None): - print("sync task", task_id, func, task_kwargs) - start = time.time() - traceback_str = None - res = {} - try: - outputs = func(task_id, **(task_kwargs or {})) - res.update(outputs) - except Exception: - traceback_str = traceback.format_exc() - finish = time.time() - if "meta" not in res: - res["meta"] = {} - res["meta"]["task_times"] = [finish - start] - if traceback_str is None: - res["status"] = "SUCCESS" +async def get_task_kwargs(callback_url, retries=5): + """ + Retrieve task kwargs from callback_url. + + Returns + ------- + resp: httpx.Response + """ + job_token = os.environ.get("JOB_TOKEN", None) + if job_token is not None: + headers = {"Authorization": f"Token {job_token}"} else: - res["status"] = "FAIL" - res["traceback"] = traceback_str - return res + headers = None + + for retry in range(0, retries + 1): + try: + async with httpx.AsyncClient() as client: + resp = await client.get(callback_url, headers=headers) + resp.raise_for_status() + return resp + except Exception as e: + print(f"Exception when retrieving value from callback url: {callback_url}") + print(f"Exception: {e}") + if retry >= retries: + raise e + wait_time = 2 ** retry + print(f"Trying again in {wait_time} seconds.") + time.sleep(wait_time) -async def async_task_wrapper(task_id, task_name, func, task_kwargs=None): - print("async task", task_id, func, task_kwargs) +async def task_wrapper(callback_url, task_name, func, task_kwargs=None): + print("async task", callback_url, func, task_kwargs) start = time.time() traceback_str = None - res = {"task_id": task_id} + res = { + "task_name": task_name, + } try: if task_kwargs is None: - if not task_id.startswith("job-"): - _task_id = f"job-{task_id}" - else: - _task_id = task_id - with redis.Redis(**redis_conn) as rclient: - task_kwargs = rclient.get(_task_id) - if task_kwargs is not None: - task_kwargs = json.loads(task_kwargs.decode()) - outputs = func(task_id, **(task_kwargs or {})) + print("getting task_kwargs") + resp = await get_task_kwargs(callback_url) + task_kwargs = resp.json()["inputs"] + print("got task_kwargs", task_kwargs) + outputs = func(**(task_kwargs or {})) res.update( { "model_version": functions.get_version(), @@ -76,22 +64,24 @@ async def async_task_wrapper(task_id, task_name, func, task_kwargs=None): ) except Exception: traceback_str = traceback.format_exc() + finish = time.time() + if "meta" not in res: res["meta"] = {} res["meta"]["task_times"] = [finish - start] + if traceback_str is None: res["status"] = "SUCCESS" else: res["status"] = "FAIL" res["traceback"] = traceback_str + print("saving results...") async with httpx.AsyncClient() as client: - resp = await client.post( - "http://outputs-processor/push/", - json={"task_name": task_name, "result": res}, - ) + resp = await client.post(callback_url, json=res, timeout=120) + print("resp", resp.status_code, resp.url) - assert resp.status_code == 200, f"Got code: {resp.status_code}" + assert resp.status_code in (200, 201), f"Got code: {resp.status_code} ({resp.text})" return res diff --git a/workers/cs_workers/models/manage.py b/workers/cs_workers/models/manage.py index 41a9e7c8..7aa853d2 100644 --- a/workers/cs_workers/models/manage.py +++ b/workers/cs_workers/models/manage.py @@ -24,6 +24,13 @@ BASE_PATH = CURR_PATH / ".." 
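The rewritten task_wrapper above replaces the Redis-based kwargs lookup with a callback protocol: the job GETs its inputs from callback_url (retrying with exponential backoff via get_task_kwargs), runs the route function, and POSTs the result payload back to the same URL. A minimal usage sketch, assuming it runs where cs_workers is installed and the callback endpoint is reachable; the URL, job id, and kwargs below are illustrative only:

import asyncio

from cs_workers.models.executors.task_wrapper import task_wrapper

def dummy_sim(meta_param_dict=None, adjustment=None):
    # Stand-in for a registered route such as routes["sim"] in job.py.
    return {"outputs": {"renderable": [], "downloadable": []}}

# Passing task_kwargs explicitly skips the initial GET against the callback
# URL; the result dict is still POSTed back to callback_url on completion.
result = asyncio.run(
    task_wrapper(
        "http://api.workers.svc.cluster.local/api/v1/jobs/callback/some-job-id/",
        "sim",
        dummy_sim,
        task_kwargs={"meta_param_dict": {}, "adjustment": {}},
    )
)
print(result["status"])  # "SUCCESS" or "FAIL", with task timings in result["meta"]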
+def strip_secrets(line, secrets): + line = line.decode() + for name, value in secrets.items(): + line = line.replace(name, "******").replace(value, "******") + return line.strip("\n") + + class BaseManager: def __init__(self, project, cs_url, cs_api_token): self.project = project @@ -152,9 +159,12 @@ def stage(self): def promote(self): self.apply_method_to_apps(method=self.promote_app) + def version(self): + self.apply_method_to_apps(method=self.get_version) + def write_app_config(self): self.apply_method_to_apps(method=self.write_secrets) - self.apply_method_to_apps(method=self._write_api_task) + # self.apply_method_to_apps(method=self._write_api_task) def apply_method_to_apps(self, method): """ @@ -266,12 +276,6 @@ def test_app_image(self, app): try: - def strip_secrets(line, secrets): - line = line.decode() - for name, value in secrets.items(): - line = line.replace(name, "******").replace(value, "******") - return line.strip("\n") - def stream_logs(container): for line in container.logs(stream=True): print(strip_secrets(line, secrets)) @@ -357,10 +361,41 @@ def push_app_image(self, app): run(f"{cmd_prefix} {self.cr}/{self.project}/{img_name}:{tag}") + def get_version(self, app, print_stdout=True): + safeowner = clean(app["owner"]) + safetitle = clean(app["title"]) + img_name = f"{safeowner}_{safetitle}_tasks" + + app_version = None + if app["tech"] == "python-paramtools": + secrets = self.config._list_secrets(app) + client = docker.from_env() + container = client.containers.run( + f"{img_name}:{self.tag}", + [ + "python", + "-c", + "from cs_config import functions; print(functions.get_version())", + ], + environment=secrets, + detach=True, + ports=None, + ) + logs = [] + for line in container.logs(stream=True): + logs.append(strip_secrets(line, secrets)) + app_version = logs[0] if logs else None + + if app_version and print_stdout: + sys.stdout.write(app_version) + + return app_version + def stage_app(self, app): + app_version = self.get_version(app, print_stdout=False) resp = httpx.post( f"{self.config.cs_url}/apps/api/v1/{app['owner']}/{app['title']}/tags/", - json={"staging_tag": self.staging_tag}, + json={"staging_tag": self.staging_tag, "version": app_version}, headers={"Authorization": f"Token {self.cs_api_token}"}, ) assert ( @@ -378,9 +413,16 @@ def promote_app(self, app): resp.status_code == 200 ), f"Got: {resp.url} {resp.status_code} {resp.text}" staging_tag = resp.json()["staging_tag"]["image_tag"] + app_version = resp.json()["staging_tag"]["version"] + if app_version is None: + app_version = self.get_version(app, print_stdout=False) resp = httpx.post( f"{self.config.cs_url}/apps/api/v1/{app['owner']}/{app['title']}/tags/", - json={"latest_tag": staging_tag or self.tag, "staging_tag": None}, + json={ + "latest_tag": staging_tag or self.tag, + "staging_tag": None, + "version": app_version, + }, headers={"Authorization": f"Token {self.cs_api_token}"}, ) assert ( @@ -584,6 +626,20 @@ def stage(args: argparse.Namespace): manager.stage() +def version(args: argparse.Namespace): + manager = Manager( + project=args.project, + tag=args.tag, + cs_url=getattr(args, "cs_url", None) or workers_config["CS_URL"], + cs_api_token=getattr(args, "cs_api_token", None), + models=args.names, + base_branch=args.base_branch, + cr=args.cr, + staging_tag=getattr(args, "staging_tag", None), + ) + manager.version() + + def cli(subparsers: argparse._SubParsersAction): parser = subparsers.add_parser( "models", description="Deploy and manage models on C/S compute cluster." 
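strip_secrets is hoisted to module scope in manage.py above so that both the log streamer in test_app_image and the new get_version helper can scrub secret names and values before container logs are printed. A quick illustration of its behavior:

from cs_workers.models.manage import strip_secrets

secrets = {"CS_API_TOKEN": "abc123"}

# Both the secret's name and its value are masked, and the trailing newline
# from the docker log stream is stripped.
print(strip_secrets(b"CS_API_TOKEN=abc123\n", secrets))  # -> ******=******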
@@ -620,6 +676,9 @@ def cli(subparsers: argparse._SubParsersAction): promote_parser = model_subparsers.add_parser("promote") promote_parser.set_defaults(func=promote) + version_parser = model_subparsers.add_parser("version") + version_parser.set_defaults(func=version) + secrets.cli(model_subparsers) parser.set_defaults(func=lambda args: print(args)) diff --git a/workers/cs_workers/services/api/__init__.py b/workers/cs_workers/services/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/workers/cs_workers/services/api/alembic.ini b/workers/cs_workers/services/api/alembic.ini new file mode 100644 index 00000000..cf785fc1 --- /dev/null +++ b/workers/cs_workers/services/api/alembic.ini @@ -0,0 +1,85 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# timezone to use when rendering the date +# within the migration file as well as the filename. +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +timezone = UTC + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; this defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path +# version_locations = %(here)s/bar %(here)s/bat alembic/versions + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +; sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +hooks=black +black.type=console_scripts +black.entrypoint=black +black.options=-l 90 + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/workers/cs_workers/services/api/alembic/README b/workers/cs_workers/services/api/alembic/README new file mode 100644 index 00000000..98e4f9c4 --- /dev/null +++ b/workers/cs_workers/services/api/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. 
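The alembic.ini above deliberately leaves sqlalchemy.url commented out; env.py (added below) builds the connection URL from the DB_HOST, DB_USER, DB_PASS, DB_NAME, and DB_PORT environment variables instead. A hedged sketch of applying the migrations programmatically, assuming it is run from workers/cs_workers/services/api/ with those variables pointing at the workers database (the defaults below mirror values.yaml and are illustrative):

import os

from alembic import command
from alembic.config import Config

os.environ.setdefault("DB_HOST", "workers-db")
os.environ.setdefault("DB_USER", "postgres")
os.environ.setdefault("DB_PASS", "password")
os.environ.setdefault("DB_NAME", "cluster_db")

# env.py assembles postgresql://user:pass@host:port/name from the DB_* vars.
cfg = Config("alembic.ini")
command.upgrade(cfg, "head")  # equivalent to `alembic upgrade head`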
\ No newline at end of file diff --git a/workers/cs_workers/services/api/alembic/env.py b/workers/cs_workers/services/api/alembic/env.py new file mode 100644 index 00000000..317aee3b --- /dev/null +++ b/workers/cs_workers/services/api/alembic/env.py @@ -0,0 +1,86 @@ +import os + +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +from cs_workers.services.api import database, models + +target_metadata = database.Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def get_url(): + DB_HOST = os.environ.get("DB_HOST", "127.0.0.1") + DB_USER = os.environ.get("DB_USER", "postgres") + DB_PASS = os.environ.get("DB_PASS", "") + DB_NAME = os.environ.get("DB_NAME", "") + DB_PORT = os.environ.get("DB_PORT", "5432") + + return f"postgresql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}" + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = get_url() + context.configure( + url=url, target_metadata=target_metadata, literal_binds=True, compare_type=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + configuration = config.get_section(config.config_ini_section) + configuration["sqlalchemy.url"] = get_url() + connectable = engine_from_config( + configuration, prefix="sqlalchemy.", poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata, compare_type=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/workers/cs_workers/services/api/alembic/script.py.mako b/workers/cs_workers/services/api/alembic/script.py.mako new file mode 100644 index 00000000..2c015630 --- /dev/null +++ b/workers/cs_workers/services/api/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/workers/cs_workers/services/api/alembic/versions/07a0d0f7e1b4_init.py b/workers/cs_workers/services/api/alembic/versions/07a0d0f7e1b4_init.py new file mode 100644 index 00000000..d7efbd4e --- /dev/null +++ b/workers/cs_workers/services/api/alembic/versions/07a0d0f7e1b4_init.py @@ -0,0 +1,88 @@ +"""Init + +Revision ID: 07a0d0f7e1b4 +Revises: +Create Date: 2021-03-21 18:17:39.921958+00:00 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "07a0d0f7e1b4" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("username", sa.String(), nullable=True), + sa.Column("url", sa.String(), nullable=True), + sa.Column("email", sa.String(), nullable=False), + sa.Column("hashed_password", sa.String(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("is_superuser", sa.Boolean(), nullable=True), + sa.Column("is_approved", sa.Boolean(), nullable=True), + sa.Column("client_id", sa.String(), nullable=True), + sa.Column("client_secret", sa.String(), nullable=True), + sa.Column("access_token", sa.String(), nullable=True), + sa.Column("access_token_expires_at", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) + op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False) + op.create_index(op.f("ix_users_username"), "users", ["username"], unique=False) + op.create_table( + "jobs", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("finished_at", sa.DateTime(), nullable=True), + sa.Column("status", sa.String(), nullable=True), + sa.Column("inputs", sa.JSON(), nullable=True), + sa.Column("outputs", sa.JSON(), nullable=True), + sa.Column("tag", sa.String(), nullable=True), + sa.ForeignKeyConstraint(["user_id"], ["users.id"],), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "projects", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("owner", sa.String(), nullable=False), + sa.Column("title", sa.String(), nullable=False), + sa.Column("tech", sa.String(), nullable=False), + sa.Column("callable_name", sa.String(), nullable=True), + sa.Column("exp_task_time", sa.String(), nullable=False), + sa.Column("cpu", sa.Integer(), nullable=True), + sa.Column("memory", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["user_id"], ["users.id"],), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "owner", "title", "user_id", name="unique_owner_title_project" + ), + ) + op.create_index(op.f("ix_projects_id"), "projects", ["id"], unique=False) + op.create_index(op.f("ix_projects_owner"), "projects", ["owner"], unique=False) + op.create_index(op.f("ix_projects_title"), "projects", ["title"], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### 
commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_projects_title"), table_name="projects") + op.drop_index(op.f("ix_projects_owner"), table_name="projects") + op.drop_index(op.f("ix_projects_id"), table_name="projects") + op.drop_table("projects") + op.drop_table("jobs") + op.drop_index(op.f("ix_users_username"), table_name="users") + op.drop_index(op.f("ix_users_id"), table_name="users") + op.drop_index(op.f("ix_users_email"), table_name="users") + op.drop_table("users") + # ### end Alembic commands ### diff --git a/workers/cs_workers/services/api/alembic/versions/49437c661c80_add_app_location_for_deployments.py b/workers/cs_workers/services/api/alembic/versions/49437c661c80_add_app_location_for_deployments.py new file mode 100644 index 00000000..a5f0513e --- /dev/null +++ b/workers/cs_workers/services/api/alembic/versions/49437c661c80_add_app_location_for_deployments.py @@ -0,0 +1,28 @@ +"""Add app location for deployments + +Revision ID: 49437c661c80 +Revises: f027333560c0 +Create Date: 2021-04-12 14:32:50.280249+00:00 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "49437c661c80" +down_revision = "f027333560c0" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("projects", sa.Column("app_location", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("projects", "app_location") + # ### end Alembic commands ### diff --git a/workers/cs_workers/services/api/alembic/versions/f027333560c0_fix_cpu_and_memory_types.py b/workers/cs_workers/services/api/alembic/versions/f027333560c0_fix_cpu_and_memory_types.py new file mode 100644 index 00000000..474a04c4 --- /dev/null +++ b/workers/cs_workers/services/api/alembic/versions/f027333560c0_fix_cpu_and_memory_types.py @@ -0,0 +1,54 @@ +"""Fix cpu and memory types + +Revision ID: f027333560c0 +Revises: 07a0d0f7e1b4 +Create Date: 2021-03-22 14:27:21.523743+00:00 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "f027333560c0" +down_revision = "07a0d0f7e1b4" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "projects", + "cpu", + existing_type=sa.INTEGER(), + type_=sa.Float(), + existing_nullable=True, + ) + op.alter_column( + "projects", + "memory", + existing_type=sa.INTEGER(), + type_=sa.Float(), + existing_nullable=True, + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "projects", + "memory", + existing_type=sa.Float(), + type_=sa.INTEGER(), + existing_nullable=True, + ) + op.alter_column( + "projects", + "cpu", + existing_type=sa.Float(), + type_=sa.INTEGER(), + existing_nullable=True, + ) + # ### end Alembic commands ### diff --git a/workers/cs_workers/services/api/database.py b/workers/cs_workers/services/api/database.py new file mode 100644 index 00000000..f410d237 --- /dev/null +++ b/workers/cs_workers/services/api/database.py @@ -0,0 +1,12 @@ +import os +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +from .settings import settings + + +engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() diff --git a/workers/cs_workers/services/api/dependencies.py b/workers/cs_workers/services/api/dependencies.py new file mode 100644 index 00000000..752949fa --- /dev/null +++ b/workers/cs_workers/services/api/dependencies.py @@ -0,0 +1,63 @@ +from typing import Generator + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer + +from jose import jwt +from pydantic import ValidationError +from sqlalchemy.orm import Session + +from . import models, schemas, security +from .settings import settings +from .database import SessionLocal, engine + +reusable_oauth2 = OAuth2PasswordBearer( + tokenUrl=f"{settings.API_PREFIX_STR}/login/access-token" +) + + +def get_db() -> Generator: + try: + db = SessionLocal() + yield db + finally: + db.close() + + +def get_current_user( + db: Session = Depends(get_db), token: str = Depends(reusable_oauth2) +) -> models.User: + print("get_current_user") + try: + payload = jwt.decode( + token, settings.API_SECRET_KEY, algorithms=[security.ALGORITHM] + ) + token_data = schemas.TokenPayload(**payload) + except (jwt.JWTError, ValidationError): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Could not validate credentials", + ) + user = db.query(models.User).filter(models.User.id == token_data.sub).one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user + + +def get_current_active_user( + current_user: models.User = Depends(get_current_user), +) -> models.User: + print("get_current_active_user") + if not current_user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + +def get_current_active_superuser( + current_user: models.User = Depends(get_current_user), +) -> models.User: + if not current_user.is_superuser: + raise HTTPException( + status_code=400, detail="The user doesn't have enough privileges" + ) + return current_user diff --git a/workers/cs_workers/services/api/main.py b/workers/cs_workers/services/api/main.py new file mode 100644 index 00000000..eb5e635e --- /dev/null +++ b/workers/cs_workers/services/api/main.py @@ -0,0 +1,25 @@ +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware + +from .settings import settings +from .routers import users, login, projects, jobs, deployments + +app = FastAPI( + title=settings.PROJECT_NAME, openapi_url=f"{settings.API_PREFIX_STR}/openapi.json", +) + +# Set all CORS enabled origins +if settings.BACKEND_CORS_ORIGINS: + app.add_middleware( + CORSMiddleware, + allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS], + allow_credentials=True, + 
allow_methods=["*"], + allow_headers=["*"], + ) + +app.include_router(login.router, prefix=settings.API_PREFIX_STR) +app.include_router(users.router, prefix=settings.API_PREFIX_STR) +app.include_router(projects.router, prefix=settings.API_PREFIX_STR) +app.include_router(jobs.router, prefix=settings.API_PREFIX_STR) +app.include_router(deployments.router, prefix=settings.API_PREFIX_STR) diff --git a/workers/cs_workers/services/api/models.py b/workers/cs_workers/services/api/models.py new file mode 100644 index 00000000..2dc40d19 --- /dev/null +++ b/workers/cs_workers/services/api/models.py @@ -0,0 +1,80 @@ +import uuid + +from sqlalchemy import ( + Boolean, + Column, + ForeignKey, + Integer, + String, + DateTime, + JSON, + Float, +) +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship +from sqlalchemy.schema import UniqueConstraint + +from .database import Base + + +class User(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True, index=True) + username = Column(String, index=True) + url = Column(String, nullable=True) + email = Column(String, unique=True, index=True, nullable=False) + + hashed_password = Column(String, nullable=False) + + is_active = Column(Boolean(), default=True) + is_superuser = Column(Boolean(), default=False) + is_approved = Column(Boolean(), default=False) + + client_id = Column(String) + client_secret = Column(String) + access_token = Column(String) + access_token_expires_at = Column(DateTime) + + jobs = relationship("Job", back_populates="user") + projects = relationship("Project", back_populates="user") + + +class Job(Base): + __tablename__ = "jobs" + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(Integer, ForeignKey("users.id")) + name = Column(String) + created_at = Column(DateTime) + finished_at = Column(DateTime) + status = Column(String) + inputs = Column(JSON) + outputs = Column(JSON) + tag = Column(String) + + user = relationship("User", back_populates="jobs") + + +class Project(Base): + __tablename__ = "projects" + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id")) + owner = Column(String, nullable=False, index=True) + title = Column(String, nullable=False, index=True) + app_location = Column(String, nullable=True) + tech = Column(String, nullable=False) + callable_name = Column(String) + exp_task_time = Column(String, nullable=False) + cpu = Column(Float) + memory = Column(Float) + + user = relationship("User", back_populates="projects") + + __table_args__ = ( + UniqueConstraint( + "owner", "title", "user_id", name="unique_owner_title_project", + ), + ) + + class Config: + orm_mode = True + extra = "ignore" diff --git a/workers/cs_workers/services/api/routers/__init__.py b/workers/cs_workers/services/api/routers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/workers/cs_workers/services/api/routers/deployments.py b/workers/cs_workers/services/api/routers/deployments.py new file mode 100644 index 00000000..bd23b2af --- /dev/null +++ b/workers/cs_workers/services/api/routers/deployments.py @@ -0,0 +1,162 @@ +import os + +from fastapi import APIRouter, Depends, Body, HTTPException +from sqlalchemy.orm import Session + +from cs_workers.models.clients import server +from .. 
import utils, models, schemas, dependencies as deps, security, settings + +incluster = os.environ.get("KUBERNETES_SERVICE_HOST", False) is not False + +PROJECT = os.environ.get("PROJECT") + + +router = APIRouter(prefix="/deployments", tags=["deployments"]) + + +@router.post( + "/{owner}/{title}/", response_model=schemas.DeploymentReadyStats, status_code=201 +) +def create_deployment( + owner: str, + title: str, + data: schemas.DeploymentCreate = Body(...), + db: Session = Depends(deps.get_db), + user: schemas.User = Depends(deps.get_current_active_user), +): + print("create deployment", data) + project: models.Project = ( + db.query(models.Project) + .filter( + models.Project.owner == owner, + models.Project.title == title, + models.Project.user_id == user.id, + ) + .one_or_none() + ) + + if not project: + raise HTTPException(status_code=404, detail="Project not found.") + + if project.tech not in ("dash", "bokeh"): + raise HTTPException(status_code=400, detail=f"Unsupported tech: {project.tech}") + + project_data = schemas.Project.from_orm(project).dict() + utils.set_resource_requirements(project_data) + + viz = server.Server( + project=PROJECT, + owner=project.owner, + title=project.title, + tag=data.tag, + model_config=project_data, + callable_name=project.callable_name, + deployment_name=data.deployment_name, + incluster=incluster, + viz_host=settings.settings.VIZ_HOST, + namespace=settings.settings.PROJECT_NAMESPACE, + ) + dep = viz.deployment_from_cluster() + if dep is not None: + raise HTTPException(status_code=400, detail="Deployment is already running.") + + viz.configure() + viz.create() + ready_stats = schemas.DeploymentReadyStats(**viz.ready_stats()) + return ready_stats + + +@router.get( + "/{owner}/{title}/{deployment_name}/", + response_model=schemas.DeploymentReadyStats, + status_code=200, +) +def get_deployment( + owner: str, + title: str, + deployment_name: str, + db: Session = Depends(deps.get_db), + user: schemas.User = Depends(deps.get_current_active_user), +): + project: models.Project = ( + db.query(models.Project) + .filter( + models.Project.owner == owner, + models.Project.title == title, + models.Project.user_id == user.id, + ) + .one_or_none() + ) + + if not project: + raise HTTPException(status_code=404, detail="Project not found.") + + if project.tech not in ("dash", "bokeh"): + raise HTTPException(status_code=400, detail=f"Unsupported tech: {project.tech}") + + project_data = schemas.Project.from_orm(project).dict() + utils.set_resource_requirements(project_data) + + viz = server.Server( + project=PROJECT, + owner=project.owner, + title=project.title, + tag=None, + model_config=project_data, + callable_name=project.callable_name, + deployment_name=deployment_name, + incluster=incluster, + viz_host=settings.settings.VIZ_HOST, + namespace=settings.settings.PROJECT_NAMESPACE, + ) + + ready_stats = schemas.DeploymentReadyStats(**viz.ready_stats()) + return ready_stats + + +@router.delete( + "/{owner}/{title}/{deployment_name}/", + response_model=schemas.DeploymentDelete, + status_code=200, +) +def delete_deployment( + owner: str, + title: str, + deployment_name: str, + db: Session = Depends(deps.get_db), + user: schemas.User = Depends(deps.get_current_active_user), +): + project: models.Project = ( + db.query(models.Project) + .filter( + models.Project.owner == owner, + models.Project.title == title, + models.Project.user_id == user.id, + ) + .one_or_none() + ) + + if not project: + raise HTTPException(status_code=404, detail="Project not found.") + + if
project.tech not in ("dash", "bokeh"): + raise HTTPException(status_code=400, detail=f"Unsupported tech: {project.tech}") + + project_data = schemas.Project.from_orm(project).dict() + utils.set_resource_requirements(project_data) + + viz = server.Server( + project=PROJECT, + owner=project.owner, + title=project.title, + tag=None, + model_config=project_data, + callable_name=project.callable_name, + deployment_name=deployment_name, + incluster=incluster, + viz_host=settings.settings.VIZ_HOST, + namespace=settings.settings.PROJECT_NAMESPACE, + ) + + delete = schemas.DeploymentDelete(**viz.delete()) + return delete diff --git a/workers/cs_workers/services/api/routers/jobs.py b/workers/cs_workers/services/api/routers/jobs.py new file mode 100644 index 00000000..6de51063 --- /dev/null +++ b/workers/cs_workers/services/api/routers/jobs.py @@ -0,0 +1,152 @@ +from datetime import datetime +import os + +import httpx +from fastapi import APIRouter, Depends, Body, HTTPException +from sqlalchemy.orm import Session + +from cs_workers.models.clients import job +from .. import utils, models, schemas, dependencies as deps, security, settings + +incluster = os.environ.get("KUBERNETES_SERVICE_HOST", False) is not False + +PROJECT = os.environ.get("PROJECT") + + +router = APIRouter(prefix="/jobs", tags=["jobs"]) + + +@router.get("/callback/{job_id}/", status_code=201, response_model=schemas.Job) +def job_callback( + job_id: str, db: Session = Depends(deps.get_db), +): + instance: models.Job = db.query(models.Job).filter( + models.Job.id == job_id + ).one_or_none() + if instance is None: + raise HTTPException(status_code=404, detail="Job not found.") + print(instance.finished_at) + if instance.finished_at: + raise HTTPException( + status_code=403, detail="No permission to retrieve job once it's finished."
+ ) + + if instance.status == "CREATED": + instance.status = "RUNNING" + db.add(instance) + db.commit() + db.refresh(instance) + + print(instance.inputs) + + return instance + + +@router.post("/callback/{job_id}/", status_code=201, response_model=schemas.Job) +async def finish_job( + job_id: str, + task: schemas.TaskComplete = Body(...), + db: Session = Depends(deps.get_db), +): + print("got data for ", job_id) + instance = db.query(models.Job).filter(models.Job.id == job_id).one_or_none() + if instance is None: + raise HTTPException(status_code=404, detail="Job not found.") + + if instance.finished_at: + raise HTTPException(status_code=400, detail="Job already marked as complete.") + + instance.outputs = task.outputs + instance.status = task.status + instance.finished_at = datetime.utcnow() + + db.add(instance) + db.commit() + db.refresh(instance) + + user = instance.user + await security.ensure_cs_access_token(db, user) + async with httpx.AsyncClient() as client: + resp = await client.post( + f"http://outputs-processor/{job_id}/", + json={ + "url": user.url, + "headers": {"Authorization": f"Bearer {user.access_token}"}, + "task": task.dict(), + }, + ) + print(resp.text) + resp.raise_for_status() + + return instance + + +@router.post("/{owner}/{title}/", response_model=schemas.Job, status_code=201) +def create_job( + owner: str, + title: str, + task: schemas.Task = Body(...), + db: Session = Depends(deps.get_db), + user: schemas.User = Depends(deps.get_current_active_user), +): + print(owner, title) + print(task.task_kwargs) + project = ( + db.query(models.Project) + .filter( + models.Project.owner == owner, + models.Project.title == title, + models.Project.user_id == user.id, + ) + .one_or_none() + ) + + if not project: + raise HTTPException(status_code=404, detail="Project not found.") + + task_name, _, task_kwargs, tag = ( + task.task_name, + task.task_id, + task.task_kwargs, + task.tag, + ) + + instance = models.Job( + user_id=user.id, + name=task_name, + created_at=datetime.utcnow(), + finished_at=None, + inputs=task_kwargs, + tag=tag, + status="CREATED", + ) + db.add(instance) + db.commit() + db.refresh(instance) + + project_data = schemas.Project.from_orm(project).dict() + utils.set_resource_requirements(project_data) + + if settings.settings.WORKERS_API_HOST: + url = f"https://{settings.settings.WORKERS_API_HOST}" + else: + url = f"http://api.{settings.settings.NAMESPACE}.svc.cluster.local" + + url += settings.settings.API_PREFIX_STR + + client = job.Job( + PROJECT, + owner, + title, + tag=tag, + model_config=project_data, + job_id=instance.id, + callback_url=f"{url}/jobs/callback/{instance.id}/", + route_name=task_name, + incluster=incluster, + namespace=settings.settings.PROJECT_NAMESPACE, + ) + + client.create() + + return instance diff --git a/workers/cs_workers/services/api/routers/login.py b/workers/cs_workers/services/api/routers/login.py new file mode 100644 index 00000000..4055d759 --- /dev/null +++ b/workers/cs_workers/services/api/routers/login.py @@ -0,0 +1,48 @@ +from datetime import timedelta, datetime +from typing import Any, Optional + +from fastapi import APIRouter, Body, Depends, HTTPException +from fastapi.security import OAuth2PasswordRequestForm + +import pytz +from sqlalchemy.orm import Session + +from .. import security +from ..models import User +from .. import schemas +from .. 
import dependencies as deps +from ..settings import settings + +router = APIRouter(tags=["login"]) + + +def authenticate(db: Session, *, username: str, password: str) -> Optional[User]: + user = db.query(User).filter(User.username == username).one_or_none() + if not user: + return None + if not security.verify_password(password, user.hashed_password): + return None + return user + + +@router.post("/login/access-token", response_model=schemas.Token) +def login_access_token( + db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends() +) -> Any: + """ + OAuth2 compatible token login, get an access token for future requests + """ + user = authenticate(db, username=form_data.username, password=form_data.password) + + if not user: + raise HTTPException(status_code=400, detail="Incorrect username or password") + elif not user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + return { + "access_token": security.create_access_token( + user.id, expires_delta=access_token_expires + ), + "token_type": "bearer", + "expires_at": datetime.now().replace(tzinfo=pytz.UTC) + access_token_expires, + } diff --git a/workers/cs_workers/services/api/routers/projects.py b/workers/cs_workers/services/api/routers/projects.py new file mode 100644 index 00000000..abd1ac73 --- /dev/null +++ b/workers/cs_workers/services/api/routers/projects.py @@ -0,0 +1,48 @@ +from typing import List + +from fastapi import APIRouter, Depends, Body +from sqlalchemy.orm import Session + +from .. import models, schemas, dependencies as deps + +router = APIRouter(prefix="/projects", tags=["projects"]) + + +@router.post("/sync/", response_model=List[schemas.Project], status_code=200) +def sync_projects( + projects: List[schemas.ProjectSync] = Body(...), + db: Session = Depends(deps.get_db), + user: schemas.User = Depends(deps.get_current_active_user), +): + orm_projects = [] + for project in projects: + orm_project = ( + db.query(models.Project) + .filter( + models.Project.title == project.title, + models.Project.owner == project.owner, + models.Project.user_id == user.id, + ) + .one_or_none() + ) + project_data = project.dict() + if orm_project is None: + print("creating object from data", project_data) + orm_project = models.Project(**project_data, user_id=user.id) + else: + print("updating object from data", project_data) + for attr, val in project.dict().items(): + print("setting", attr, val) + setattr(orm_project, attr, val) + orm_projects.append(orm_project) + db.add_all(orm_projects) + db.commit() + return orm_projects + + +@router.get("/", response_model=List[schemas.Project], status_code=200) +def get_projects( + db: Session = Depends(deps.get_db), + user: schemas.User = Depends(deps.get_current_active_user), +): + return user.projects diff --git a/workers/cs_workers/services/api/routers/users.py b/workers/cs_workers/services/api/routers/users.py new file mode 100644 index 00000000..bcb34b83 --- /dev/null +++ b/workers/cs_workers/services/api/routers/users.py @@ -0,0 +1,92 @@ +import base64 +from typing import Any + +import httpx +from fastapi import APIRouter, Body, Depends, HTTPException +from sqlalchemy.orm import Session +from pydantic.networks import EmailStr, AnyHttpUrl # pylint: disable=no-name-in-module + +from .. 
diff --git a/workers/cs_workers/services/api/routers/users.py b/workers/cs_workers/services/api/routers/users.py
new file mode 100644
index 00000000..bcb34b83
--- /dev/null
+++ b/workers/cs_workers/services/api/routers/users.py
@@ -0,0 +1,92 @@
+import base64
+from typing import Any
+
+import httpx
+from fastapi import APIRouter, Body, Depends, HTTPException
+from sqlalchemy.orm import Session
+from pydantic.networks import EmailStr, AnyHttpUrl  # pylint: disable=no-name-in-module
+
+from .. import schemas, models, dependencies as deps, security
+
+router = APIRouter(prefix="/users", tags=["users"])
+
+
+@router.get("/me", response_model=schemas.User)
+def read_user_me(
+    db: Session = Depends(deps.get_db),
+    current_user: models.User = Depends(deps.get_current_active_user),
+) -> Any:
+    """
+    Get current user.
+    """
+    return current_user
+
+
+@router.post("/", response_model=schemas.User, status_code=201)
+def create_user(
+    *,
+    db: Session = Depends(deps.get_db),
+    password: str = Body(...),
+    email: EmailStr = Body(...),
+    url: AnyHttpUrl = Body(...),
+    username: str = Body(None),
+    client_id: str = Body(...),
+    client_secret: str = Body(...),
+) -> models.User:
+    """
+    Create new user.
+    """
+    user = db.query(models.User).filter(models.User.username == username).one_or_none()
+    if user:
+        raise HTTPException(
+            status_code=400,
+            detail="The user with this username already exists in the system",
+        )
+    user_in = schemas.UserCreate(
+        password=password, email=email, username=username, url=url
+    )
+    user_db = models.User(
+        email=user_in.email,
+        username=user_in.username,
+        url=user_in.url,
+        hashed_password=security.get_password_hash(user_in.password),
+        client_id=client_id,
+        client_secret=client_secret,
+    )
+    db.add(user_db)
+    db.commit()
+    db.refresh(user_db)
+    return user_db
+
+
+@router.get("/ping/", status_code=200)
+async def ping(
+    *,
+    db: Session = Depends(deps.get_db),
+    current_user: models.User = Depends(deps.get_current_active_user),
+):
+    await security.ensure_cs_access_token(db, current_user)
+
+
+@router.post("/approve/", response_model=schemas.User)
+def approve_user(
+    *,
+    db: Session = Depends(deps.get_db),
+    current_super_user: models.User = Depends(deps.get_current_active_superuser),
+    user_approve: schemas.UserApprove = Body(...),
+) -> models.User:
+    """
+    Approve or un-approve a user.
+    """
+    user: models.User = db.query(models.User).filter(
+        models.User.username == user_approve.username
+    ).one_or_none()
+    if not user:
+        raise HTTPException(
+            status_code=400, detail="The user with this username does not exist",
+        )
+    user.is_approved = user_approve.is_approved
+    db.add(user)
+    db.commit()
+    db.refresh(user)
+    return user
diff --git a/workers/cs_workers/services/api/schemas.py b/workers/cs_workers/services/api/schemas.py
new file mode 100644
index 00000000..5c771dcf
--- /dev/null
+++ b/workers/cs_workers/services/api/schemas.py
@@ -0,0 +1,148 @@
+from datetime import datetime
+from typing import List, Optional, Dict, Any
+from enum import Enum
+import uuid
+
+from pydantic import BaseModel, Json  # pylint: disable=no-name-in-module
+from pydantic.networks import EmailStr, AnyHttpUrl  # pylint: disable=no-name-in-module
+
+
+class JobBase(BaseModel):
+    user_id: int
+    created_at: datetime
+    name: str
+    finished_at: Optional[datetime]
+    status: str
+    inputs: Optional[Dict]
+    outputs: Optional[Dict]
+    traceback: Optional[str]
+    tag: str
+
+
+class JobCreate(JobBase):
+    pass
+
+
+class Job(JobBase):
+    id: uuid.UUID
+
+    class Config:
+        orm_mode = True
+
+
+class TaskComplete(BaseModel):
+    model_version: Optional[str]
+    outputs: Optional[Dict]
+    traceback: Optional[str]
+    version: Optional[str]
+    meta: Dict  # Dict[str, str]
+    status: str
+    task_name: str
+
+
+class Task(BaseModel):
+    task_id: Optional[str]
+    task_name: str
+    task_kwargs: Dict  # Dict[str, str]
+    tag: str
+
+
+# Shared properties
+class UserBase(BaseModel):
+    email: Optional[EmailStr] = None
+    username: Optional[str] = None
+    url: Optional[AnyHttpUrl]
+    is_approved: Optional[bool]
+
+    is_active: Optional[bool] = True
+
+
+# Properties to receive via API on creation
+class UserCreate(UserBase):
+    email: EmailStr = None
+    username: str = None
+    url: AnyHttpUrl
+    password: str
+
+
+class UserApprove(UserBase):
+    username: str
+    is_approved: bool
+
+
+class UserInDBBase(UserBase):
+    class Config:
+        orm_mode = True
+
+
+# Additional properties to return via API
+class User(UserInDBBase):
+    pass
+
+
+# Additional properties stored in DB
+class UserInDB(UserInDBBase):
+    id: Optional[int] = None
+    hashed_password: str
+
+
+class Token(BaseModel):
+    access_token: str
+    token_type: str
+    expires_at: datetime
+
+
+class TokenPayload(BaseModel):
+    sub: Optional[int] = None
+
+
+class CSOauthResponse(BaseModel):
+    access_token: str
+    expires_in: int
+    token_type: str
+    scope: str
+
+
+class ProjectSync(BaseModel):
+    owner: str
+    title: str
+    tech: str
+    callable_name: Optional[str]
+    app_location: Optional[str]
+    exp_task_time: int
+    cpu: float
+    memory: float
+
+
+class Project(ProjectSync):
+    id: int
+
+    class Config:
+        orm_mode = True
+        extra = "ignore"
+
+
+class DeploymentCreate(BaseModel):
+    tag: str
+    deployment_name: str
+
+
+class ReadyStats(BaseModel):
+    created_at: Optional[datetime]
+    ready: bool
+
+
+class DeploymentReadyStats(BaseModel):
+    deployment: ReadyStats
+    svc: ReadyStats
+    ingressroute: ReadyStats
+
+
+class Deleted(BaseModel):
+    deleted: bool
+
+
+class DeploymentDelete(BaseModel):
+    deployment: Deleted
+    svc: Deleted
+    ingressroute: Deleted
diff --git a/workers/cs_workers/services/api/scripts/create_super_user.py b/workers/cs_workers/services/api/scripts/create_super_user.py
new file mode 100644
index 00000000..9d1139b6
--- /dev/null
+++ b/workers/cs_workers/services/api/scripts/create_super_user.py
@@ -0,0 +1,31 @@
+import argparse
+from getpass import getpass
+
+from cs_workers.services.api.security import get_password_hash
+from cs_workers.services.api.models import User
+from cs_workers.services.api.database import SessionLocal
+from cs_workers.services.api.schemas import User as UserSchema
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--username")
+    parser.add_argument("--email")
+    args = parser.parse_args()
+
+    password = getpass()
+    assert password, "Password required."
+
+    user = User(
+        username=args.username,
+        hashed_password=get_password_hash(password),
+        email=args.email,
+        is_superuser=True,
+        is_active=True,
+        url=None,
+    )
+    session = SessionLocal()
+    session.add(user)
+    session.commit()
+    session.refresh(user)
+    print("User created successfully:")
+    print(UserSchema.from_orm(user).dict())
diff --git a/workers/cs_workers/services/api/security.py b/workers/cs_workers/services/api/security.py
new file mode 100644
index 00000000..bc79d2a7
--- /dev/null
+++ b/workers/cs_workers/services/api/security.py
@@ -0,0 +1,70 @@
+from datetime import datetime, timedelta
+from typing import Any, Union
+
+import httpx
+from jose import jwt
+from passlib.context import CryptContext
+from sqlalchemy.orm import Session
+
+from fastapi import HTTPException
+
+from .settings import settings
+from . import schemas
+from . import models
+
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+
+
+ALGORITHM = "HS256"
+
+
+def create_access_token(
+    subject: Union[str, Any], expires_delta: timedelta = None
+) -> str:
+    if expires_delta:
+        expire = datetime.utcnow() + expires_delta
+    else:
+        expire = datetime.utcnow() + timedelta(
+            minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
+        )
+    to_encode = {"exp": expire, "sub": str(subject)}
+    encoded_jwt = jwt.encode(to_encode, settings.API_SECRET_KEY, algorithm=ALGORITHM)
+    return encoded_jwt
+
+
+def verify_password(plain_password: str, hashed_password: str) -> bool:
+    return pwd_context.verify(plain_password, hashed_password)
+
+
+def get_password_hash(password: str) -> str:
+    return pwd_context.hash(password)
+
+
+async def ensure_cs_access_token(db: Session, user: models.User):
+    missing_token = user.access_token is None
+    # Refresh if the token has expired or will expire within the next minute.
+    is_expired = (
+        user.access_token_expires_at is not None
+        and user.access_token_expires_at < (datetime.utcnow() + timedelta(seconds=60))
+    )
+    if missing_token or is_expired:
+        async with httpx.AsyncClient() as client:
+            resp = await client.post(
+                f"{user.url}/o/token/",
+                data={
+                    "grant_type": "client_credentials",
+                    "client_id": user.client_id,
+                    "client_secret": user.client_secret,
+                },
+            )
+            if resp.status_code != 200:
+                raise HTTPException(status_code=400, detail=resp.text)
+            data = schemas.CSOauthResponse(**resp.json())
+            user.access_token = data.access_token
+            user.access_token_expires_at = datetime.utcnow() + timedelta(
+                seconds=data.expires_in
+            )
+            db.add(user)
+            db.commit()
+            db.refresh(user)
+    return user
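create_access_token signs nothing but a subject and an expiry. A round-trip sketch using the same python-jose API (assumes the settings environment is configured):

    from jose import jwt

    from cs_workers.services.api import security
    from cs_workers.services.api.settings import settings

    token = security.create_access_token(subject=42)
    claims = jwt.decode(token, settings.API_SECRET_KEY, algorithms=[security.ALGORITHM])
    assert claims["sub"] == "42"  # subjects are stringified on encode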
diff --git a/workers/cs_workers/services/api/settings.py b/workers/cs_workers/services/api/settings.py
new file mode 100644
index 00000000..96e5da40
--- /dev/null
+++ b/workers/cs_workers/services/api/settings.py
@@ -0,0 +1,95 @@
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import AnyHttpUrl, BaseSettings, EmailStr, HttpUrl, PostgresDsn, validator
+
+NAMESPACE_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
+
+
+class Settings(BaseSettings):
+    API_PREFIX_STR: str = "/api/v1"
+    API_SECRET_KEY: str
+    # 60 minutes * 24 hours * 8 days = 8 days
+    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
+    SERVER_NAME: Optional[str]
+    SERVER_HOST: Optional[AnyHttpUrl]
+
+    BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [
+        "http://10.0.0.137:5000",
+        "http://localhost:5000",
+        "https://hdoupe.ngrok.io",
+    ]
+
+    WORKERS_API_HOST: Optional[str]
+    VIZ_HOST: str
+
+    @validator("BACKEND_CORS_ORIGINS", pre=True)
+    def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
+        if isinstance(v, str) and not v.startswith("["):
+            return [i.strip() for i in v.split(",")]
+        elif isinstance(v, (list, str)):
+            return v
+        raise ValueError(v)
+
+    PROJECT_NAMESPACE: str
+
+    @validator("PROJECT_NAMESPACE", pre=True, always=True)
+    def get_project_namespace(cls, v: Optional[str]) -> str:
+        return v or "default"
+
+    NAMESPACE: Optional[str]
+
+    # always=True so the fallback runs even when NAMESPACE is unset.
+    @validator("NAMESPACE", pre=True, always=True)
+    def get_namespace(cls, v: Optional[str]) -> str:
+        if v:
+            return v
+        elif Path(NAMESPACE_PATH).exists():
+            with open(NAMESPACE_PATH) as f:
+                return f.read().strip()
+        else:
+            return "default"
+
+    PROJECT_NAME: str = "C/S Cluster Api"
+    SENTRY_DSN: Optional[HttpUrl] = None
+
+    @validator("SENTRY_DSN", pre=True)
+    def sentry_dsn_can_be_blank(cls, v: str) -> Optional[str]:
+        # Treat an unset or empty value as "Sentry disabled."
+        if not v:
+            return None
+        return v
+
+    DB_HOST: str
+    DB_USER: str
+    DB_PASS: str
+    DB_NAME: str
+
+    TEST_DB_NAME: str = "test"
+    TEST_DB_PASS: str = os.environ.get("TEST_DB_PASS", "test")
+
+    SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
+
+    @validator("SQLALCHEMY_DATABASE_URI", pre=True)
+    def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
+        if isinstance(v, str):
+            return v
+        return "{scheme}://{user}:{password}@{host}{path}".format(
+            scheme="postgresql",
+            user=values.get("DB_USER"),
+            password=values.get("DB_PASS"),
+            host=values.get("DB_HOST"),
+            path=f"/{values.get('DB_NAME')}",
+        )
+
+    FIRST_SUPERUSER: Optional[EmailStr]
+    FIRST_SUPERUSER_PASSWORD: Optional[str]
+
+    JOB_NAMESPACE: str = "worker-api"
+
+    class Config:
+        case_sensitive = True
+
+
+settings = Settings()
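Settings is a pydantic BaseSettings class, so every field without a default must come from the environment. A sketch of a minimal local configuration (all values are placeholders):

    import os

    # Placeholders; each Settings field without a default must be provided.
    os.environ.update(
        API_SECRET_KEY="change-me",
        VIZ_HOST="viz.example.com",
        PROJECT_NAMESPACE="projects",
        DB_HOST="localhost",
        DB_USER="postgres",
        DB_PASS="example-password",
        DB_NAME="workers",
    )

    from cs_workers.services.api.settings import settings

    assert settings.SQLALCHEMY_DATABASE_URI.startswith("postgresql://")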
diff --git a/workers/cs_workers/services/api/tests/__init__.py b/workers/cs_workers/services/api/tests/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/workers/cs_workers/services/api/tests/__init__.py
@@ -0,0 +1 @@
+
diff --git a/workers/cs_workers/services/api/tests/conftest.py b/workers/cs_workers/services/api/tests/conftest.py
new file mode 100644
index 00000000..4d4b3404
--- /dev/null
+++ b/workers/cs_workers/services/api/tests/conftest.py
@@ -0,0 +1,154 @@
+from typing import Dict, Generator
+
+import pytest
+from fastapi.testclient import TestClient
+
+import sqlalchemy as sa
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session, sessionmaker, exc
+
+from ..settings import settings
+from ..database import SessionLocal
+from ..main import app
+from ..dependencies import get_db
+from .. import models, schemas, security
+
+
+SQLALCHEMY_DATABASE_URI = f"postgresql://{settings.DB_USER}:{settings.TEST_DB_PASS}@{settings.DB_HOST}/{settings.TEST_DB_NAME}"
+assert settings.DB_NAME != settings.TEST_DB_NAME
+
+engine = create_engine(SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
+
+Base = declarative_base()
+
+Base.metadata.create_all(bind=engine)
+
+
+# Adapted from:
+# https://github.com/jeancochrane/pytest-flask-sqlalchemy/blob/c109469f83450b8c5ff5de962faa1105064f5619/pytest_flask_sqlalchemy/fixtures.py#L25-L84
+@pytest.fixture(scope="function")
+def db(request) -> Generator:
+    connection = engine.connect()
+    transaction = connection.begin()
+    TestingSessionLocal = sessionmaker(
+        autocommit=False, autoflush=False, bind=connection
+    )
+    session = TestingSessionLocal()
+
+    # Make sure the session, connection, and transaction can't be closed by accident in
+    # the codebase
+    connection.force_close = connection.close
+    transaction.force_rollback = transaction.rollback
+
+    connection.close = lambda: None
+    transaction.rollback = lambda: None
+    session.close = lambda: None
+
+    session.begin_nested()
+    # Each time the SAVEPOINT for the nested transaction ends, reopen it
+    @sa.event.listens_for(session, "after_transaction_end")
+    def restart_savepoint(session, trans):
+        if trans.nested and not trans._parent.nested:
+            # ensure that state is expired the way
+            # session.commit() at the top level normally does
+            session.expire_all()
+
+            session.begin_nested()
+
+    # Force the connection to use nested transactions
+    connection.begin = connection.begin_nested
+
+    # If an object gets moved to the 'detached' state by a call to flush the session,
+    # add it back into the session (this allows us to see changes made to objects
+    # in the context of a test, even when the change was made elsewhere in
+    # the codebase)
+    @sa.event.listens_for(session, "persistent_to_detached")
+    @sa.event.listens_for(session, "deleted_to_detached")
+    def rehydrate_object(session, obj):
+        session.add(obj)
+
+    @request.addfinalizer
+    def teardown_transaction():
+        # Delete the session
+        session.close()
+
+        # Rollback the transaction and return the connection to the pool
+        transaction.force_rollback()
+        connection.force_close()
+
+    app.dependency_overrides[get_db] = lambda: session
+    return session
+
+
+# # sa 1.4???
+# # https://github.com/jeancochrane/pytest-flask-sqlalchemy/blob/c109469f83450b8c5ff5de962faa1105064f5619/pytest_flask_sqlalchemy/fixtures.py#L25-L84
+# @pytest.fixture(scope="function")
+# def db(request) -> Generator:
+#     engine = create_engine(SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
+#     connection = engine.connect()
+#     transaction = connection.begin()
+
+#     Base = declarative_base()
+
+#     Base.metadata.create_all(bind=connection)
+
+#     TestingSessionLocal = sessionmaker(
+#         autocommit=False, autoflush=False, bind=connection
+#     )
+#     session = TestingSessionLocal()
+#     try:
+#         session.begin()
+
+#         app.dependency_overrides[get_db] = lambda: session
+#         yield session
+#     finally:
+#         session.rollback()
+#         transaction.rollback()
+
+
+@pytest.fixture(scope="function")
+def client() -> Generator:
+    with TestClient(app) as c:
+        yield c
+
+
+@pytest.fixture(scope="function")
+def new_user(db):
+    user_ = models.User(
+        username="test",
+        email="test@test.com",
+        url="http://localhost:8000",
+        hashed_password=security.get_password_hash("heyhey2222"),
+        client_id="abc123",
+        client_secret="abc123",
+    )
+    db.add(user_)
+    db.commit()
+    db.refresh(user_)
+    return user_
+
+
+@pytest.fixture(scope="function")
+def user(db, new_user):
+    # "is_approved" is the model attribute; setting "approved" would do nothing.
+    new_user.is_approved = True
+    db.add(new_user)
+    db.commit()
+    db.refresh(new_user)
+    return new_user
+
+
+@pytest.fixture(scope="function")
+def superuser(db):
+    user_ = models.User(
+        username="super-user",
+        email="super-user@test.com",
+        url="http://localhost:8000",
+        hashed_password=security.get_password_hash("heyhey2222"),
+        is_superuser=True,
+    )
+    db.add(user_)
+    db.commit()
+    db.refresh(user_)
+    yield user_
diff --git a/workers/cs_workers/services/api/tests/test_projects.py b/workers/cs_workers/services/api/tests/test_projects.py
new file mode 100644
index 00000000..cabdf844
--- /dev/null
+++ b/workers/cs_workers/services/api/tests/test_projects.py
@@ -0,0 +1,85 @@
+from .utils import get_access_token
+from ..settings import settings
+from ..models import Project
+
+
+class TestProjects:
+    def test_sync_projects(self, db, client, user):
+        access_token = get_access_token(client, user)
+        assert access_token
+        data = {
+            "owner": "test",
+            "title": "test-app",
+            "tech": "bokeh",
+            "callable_name": "hello",
+            "exp_task_time": 10,
+            "cpu": 4,
+            "memory": 10,
+        }
+
+        resp = client.post(
+            f"{settings.API_PREFIX_STR}/projects/sync/",
+            json=[data],
+            headers={"Authorization": f"Bearer {access_token}"},
+        )
+        assert resp.status_code == 200
+        assert resp.json()
+        assert (
+            db.query(Project)
+            .filter(Project.owner == "test", Project.title == "test-app")
+            .one()
+        )
+
+        resp = client.post(
+            f"{settings.API_PREFIX_STR}/projects/sync/",
+            json=[data],
+            headers={"Authorization": f"Bearer {access_token}"},
+        )
+        assert resp.status_code == 200
+        assert resp.json()
+        assert (
+            db.query(Project)
+            .filter(Project.owner == "test", Project.title == "test-app")
+            .one()
+        )
+
+        resp = client.post(
+            f"{settings.API_PREFIX_STR}/projects/sync/",
+            json=[dict(data, title="test-app-another")],
+            headers={"Authorization": f"Bearer {access_token}"},
+        )
+        assert resp.status_code == 200
+        assert resp.json()
+        assert (
+            db.query(Project)
+            .filter(Project.owner == "test", Project.title == "test-app-another")
+            .one()
+        )
+        assert db.query(Project).count() == 2
+
+    def test_get_projects(self, db, client, user):
+        access_token = get_access_token(client, user)
+        assert access_token
+        data = {
+            "owner": "test",
+            "title": "test-app",
+            "tech": "bokeh",
+            "callable_name": "hello",
"exp_task_time": 10, + "cpu": 4, + "memory": 10, + } + for i in range(3): + resp = client.post( + f"{settings.API_PREFIX_STR}/projects/sync/", + json=[dict(data, title=f"new-app-{i}")], + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert resp.status_code == 200 + + resp = client.get( + f"{settings.API_PREFIX_STR}/projects/", + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert resp.status_code == 200 + assert len(resp.json()) == db.query(Project).count() == 3 diff --git a/workers/cs_workers/services/api/tests/test_users.py b/workers/cs_workers/services/api/tests/test_users.py new file mode 100644 index 00000000..4b22ebab --- /dev/null +++ b/workers/cs_workers/services/api/tests/test_users.py @@ -0,0 +1,68 @@ +from .. import models + + +class TestUsers: + def test_create_user(self, db, client): + resp = client.post( + "/api/v1/users/", + json={ + "email": "new_user@test.com", + "username": "new_user", + "password": "hello world", + "url": "https://example.com", + "client_id": "abc123", + "client_secret": "abc123", + }, + ) + assert resp.status_code == 201, resp.text + assert resp.json() == { + "email": "new_user@test.com", + "username": "new_user", + "url": "https://example.com", + "is_approved": False, + "is_active": True, + } + + def test_login_user(self, db, client, new_user): + resp = client.post( + "/api/v1/login/access-token", + data={"username": "test", "password": "heyhey2222"}, + ) + assert resp.status_code == 200, f"Got {resp.status_code}: {resp.text}" + assert resp.json().get("access_token") + assert resp.json().get("expires_at") + + def test_get_current_user(self, db, client, new_user): + resp = client.post( + "/api/v1/login/access-token", + data={"username": "test", "password": "heyhey2222"}, + ) + assert resp.status_code == 200, f"Got {resp.status_code}: {resp.text}" + access_token = resp.json()["access_token"] + resp = client.get( + "/api/v1/users/me/", headers={"Authorization": f"Bearer {access_token}"} + ) + assert resp.status_code == 200, f"Got {resp.status_code}: {resp.text}" + assert resp.json()["username"] + + def test_approve_user(self, db, client, new_user, superuser): + resp = client.post( + "/api/v1/login/access-token", + data={"username": "super-user", "password": "heyhey2222"}, + ) + + assert resp.status_code == 200 + + assert superuser.is_superuser + assert not new_user.is_approved + print("user.username", new_user.username, type(new_user.username)) + access_token = resp.json()["access_token"] + resp = client.post( + "/api/v1/users/approve/", + json={"username": new_user.username, "is_approved": True}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert resp.status_code == 200, f"Got {resp.status_code}: {resp.text}" + + db.refresh(new_user) + assert new_user.is_approved diff --git a/workers/cs_workers/services/api/tests/utils.py b/workers/cs_workers/services/api/tests/utils.py new file mode 100644 index 00000000..263aec70 --- /dev/null +++ b/workers/cs_workers/services/api/tests/utils.py @@ -0,0 +1,7 @@ +def get_access_token(test_client, user): + resp = test_client.post( + "/api/v1/login/access-token", + data={"username": "test", "password": "heyhey2222"}, + ) + assert resp.status_code == 200, f"Got {resp.status_code}: {resp.text}" + return resp.json()["access_token"] diff --git a/workers/cs_workers/services/api/utils.py b/workers/cs_workers/services/api/utils.py new file mode 100644 index 00000000..2093a33d --- /dev/null +++ b/workers/cs_workers/services/api/utils.py @@ -0,0 +1,11 @@ +import math + + +def 
set_resource_requirements(project_data): + mem = float(project_data.pop("memory")) + cpu = float(project_data.pop("cpu")) + if cpu and mem: + project_data["resources"] = { + "requests": {"memory": f"{mem}G", "cpu": cpu}, + "limits": {"memory": f"{math.ceil(mem * 1.2)}G", "cpu": cpu,}, + } diff --git a/workers/cs_workers/services/manage.py b/workers/cs_workers/services/manage.py index aa1e2fd6..6b0d5f17 100644 --- a/workers/cs_workers/services/manage.py +++ b/workers/cs_workers/services/manage.py @@ -11,7 +11,8 @@ from cs_deploy.config import workers_config from cs_workers.services.secrets import ServicesSecrets -from cs_workers.services import scheduler + +# from cs_workers.services import scheduler CURR_PATH = Path(os.path.abspath(os.path.dirname(__file__))) BASE_PATH = CURR_PATH / ".." @@ -49,7 +50,7 @@ class Manager: Deploy and manage Compute Studio compute cluster: - build, tag, and push the docker images for the flask app and compute.studio modeling apps. - - write k8s config files for the scheduler deployment and the + - write k8s config files for the workers_api deployment and the compute.studio modeling app deployments. - apply k8s config files to an existing compute cluster. @@ -93,36 +94,6 @@ def __init__( self.templates_dir = BASE_PATH / Path("templates") self.dockerfiles_dir = BASE_PATH / Path("dockerfiles") - with open( - self.templates_dir / "services" / "scheduler-Deployment.template.yaml", "r" - ) as f: - self.scheduler_template = yaml.safe_load(f.read()) - - with open( - self.templates_dir / "services" / "scheduler-ingressroute.template.yaml", - "r", - ) as f: - self.scheduler_ir_template = yaml.safe_load(f.read()) - - with open( - self.templates_dir - / "services" - / "outputs-processor-Deployment.template.yaml", - "r", - ) as f: - self.outputs_processor_template = yaml.safe_load(f.read()) - with open( - self.templates_dir / "services" / "outputs-processor-ServiceAccount.yaml", - "r", - ) as f: - self.outputs_processor_serviceaccount = yaml.safe_load(f.read()) - - with open( - self.templates_dir / "services" / "redis-master-Deployment.template.yaml", - "r", - ) as f: - self.redis_master_template = yaml.safe_load(f.read()) - with open(self.templates_dir / "secret.template.yaml", "r") as f: self.secret_template = yaml.safe_load(f.read()) @@ -131,33 +102,26 @@ def __init__( def build(self): """ - Build, tag, and push base images for the scheduler app. + Build, tag, and push base images for the workers_api app. Note: distributed and celerybase are tagged as "latest." All other apps pull from either distributed:latest or celerybase:latest. 
""" distributed = self.dockerfiles_dir / "Dockerfile" - redis = self.dockerfiles_dir / "Dockerfile.redis" outputs_processor = self.dockerfiles_dir / "Dockerfile.outputs_processor" - scheduler = self.dockerfiles_dir / "Dockerfile.scheduler" + workers_api = self.dockerfiles_dir / "Dockerfile.workers_api" run(f"docker build -t distributed:latest -f {distributed} ./") - run(f"docker build -t redis-python:{self.tag} -f {redis} ./") run(f"docker build -t outputs_processor:{self.tag} -f {outputs_processor} ./") - run(f"docker build -t scheduler:{self.tag} -f {scheduler} ./") + run(f"docker build -t workers_api:{self.tag} -f {workers_api} ./") def push(self): run(f"docker tag distributed {self.cr}/{self.project}/distributed:latest") - run( - f"docker tag redis-python:{self.tag} {self.cr}/{self.project}/redis-python:{self.tag}" - ) - run( f"docker tag outputs_processor:{self.tag} {self.cr}/{self.project}/outputs_processor:{self.tag}" ) - run( - f"docker tag scheduler:{self.tag} {self.cr}/{self.project}/scheduler:{self.tag}" + f"docker tag workers_api:{self.tag} {self.cr}/{self.project}/workers_api:{self.tag}" ) if self.use_kind: @@ -166,119 +130,12 @@ def push(self): cmd_prefix = "docker push" run(f"{cmd_prefix} {self.cr}/{self.project}/distributed:latest") - run(f"{cmd_prefix} {self.cr}/{self.project}/redis-python:{self.tag}") run(f"{cmd_prefix} {self.cr}/{self.project}/outputs_processor:{self.tag}") - run(f"{cmd_prefix} {self.cr}/{self.project}/scheduler:{self.tag}") + run(f"{cmd_prefix} {self.cr}/{self.project}/workers_api:{self.tag}") def config(self, update_redis=False, update_dns=False): - config_filenames = [ - "scheduler-Service.yaml", - "scheduler-RBAC.yaml", - "outputs-processor-Service.yaml", - "job-cleanup-Job.yaml", - "job-cleanup-RBAC.yaml", - ] - if update_redis: - config_filenames.append("redis-master-Service.yaml") - for filename in config_filenames: - with open(self.templates_dir / "services" / f"{filename}", "r") as f: - configs = yaml.safe_load_all(f.read()) - for config in configs: - name = config["metadata"]["name"] - kind = config["kind"] - self.write_config(f"{name}-{kind}.yaml", config) - self.write_scheduler_deployment() if update_dns: - self.write_scheduler_ingressroute() self.write_cloudflare_api_token() - self.write_outputs_processor_deployment() - self.write_secret() - if update_redis: - self.write_redis_deployment() - - def write_scheduler_deployment(self): - """ - Write scheduler deployment file. Only step is filling in the image uri. - """ - deployment = copy.deepcopy(self.scheduler_template) - deployment["spec"]["template"]["spec"]["containers"][0][ - "image" - ] = f"gcr.io/{self.project}/scheduler:{self.tag}" - deployment["spec"]["template"]["spec"]["containers"][0]["env"] += [ - {"name": "VIZ_HOST", "value": self.viz_host}, - ] - self.write_config("scheduler-Deployment.yaml", deployment) - - return deployment - - def write_scheduler_ingressroute(self): - """ - Write scheduler ingressroute file. Only step is filling in the cluster host. - """ - ir = copy.deepcopy(self.scheduler_ir_template) - ir["spec"]["routes"][0]["match"] = f"Host(`{self.cluster_host}`)" - self.write_config("scheduler-ingressroute.yaml", ir) - - return ir - - def write_outputs_processor_deployment(self): - """ - Write outputs processor deployment file. Only step is filling - in the image uri. 
- """ - deployment = copy.deepcopy(self.outputs_processor_template) - deployment["spec"]["template"]["spec"]["containers"][0][ - "image" - ] = f"gcr.io/{self.project}/outputs_processor:{self.tag}" - - self.write_config( - "outputs-processor-ServiceAccount.yaml", - self.outputs_processor_serviceaccount, - ) - self.write_config("outputs-processor-Deployment.yaml", deployment) - - return deployment - - def write_redis_deployment(self): - deployment = copy.deepcopy(self.redis_master_template) - container = deployment["spec"]["template"]["spec"]["containers"][0] - container["image"] = f"gcr.io/{self.project}/redis-python:{self.tag}" - redis_secrets = self.redis_secrets() - for name, sec in redis_secrets.items(): - if sec is not None: - container["env"].append( - { - "name": name, - "valueFrom": { - "secretKeyRef": {"key": name, "name": "worker-secret"} - }, - } - ) - - if workers_config.get("redis"): - redis_config = workers_config["redis"] - assert ( - redis_config.get("provider") == "volume" - ), f"Got: {redis_config.get('provider', None)}" - args = redis_config["args"][0] - deployment["spec"]["template"]["spec"]["volumes"] = args["volumes"] - self.write_config("redis-master-Deployment.yaml", deployment) - - def write_secret(self): - assert self.bucket - assert self.project - secrets = copy.deepcopy(self.secret_template) - secrets["stringData"]["BUCKET"] = self.bucket - secrets["stringData"]["PROJECT"] = self.project - secrets["stringData"]["CS_CRYPT_KEY"] = workers_config.get( - "CS_CRYPT_KEY" - ) or self.secrets.get("CS_CRYPT_KEY") - redis_secrets = self.redis_secrets() - for name, sec in redis_secrets.items(): - if sec is not None: - secrets["stringData"][name] = sec - - self.write_config("secret.yaml", secrets) def write_cloudflare_api_token(self): api_token = self.secrets.get("CLOUDFLARE_API_TOKEN") @@ -302,34 +159,6 @@ def write_config(self, filename, config): with open(f"{self.kubernetes_target}/{filename}", "w") as f: f.write(yaml.dump(config)) - def redis_secrets(self): - """ - Return redis ACL user passwords. If they are not in the secret manager, - try to generate them using a local instance of redis. If this fails, - they are set to an empty string. 
- """ - if self._redis_secrets is not None: - return self._redis_secrets - from google.api_core import exceptions - - redis_secrets = dict( - REDIS_ADMIN_PW="", - REDIS_EXECUTOR_PW="", - REDIS_SCHEDULER_PW="", - REDIS_OUTPUTS_PW="", - ) - for sec in redis_secrets: - try: - value = self.secrets.get(sec) - except exceptions.NotFound: - try: - value = redis_acl_genpass() - self.secrets.set(sec, value) - except Exception: - value = "" - redis_secrets[sec] = value - return redis_secrets - @property def secrets(self): if self._secrets is None: @@ -365,11 +194,12 @@ def config_(args: argparse.Namespace): def port_forward(args: argparse.Namespace): - run("kubectl port-forward svc/scheduler 8888:80") + run("kubectl port-forward svc/workers_api 8888:80") def serve(args: argparse.Namespace): - scheduler.run() + # workers_api.run() + pass def cli(subparsers: argparse._SubParsersAction, config=None, **kwargs): diff --git a/workers/cs_workers/services/outputs_processor.py b/workers/cs_workers/services/outputs_processor.py index e5e2dc7b..49aff979 100644 --- a/workers/cs_workers/services/outputs_processor.py +++ b/workers/cs_workers/services/outputs_processor.py @@ -3,9 +3,12 @@ import os import httpx +from pydantic import BaseModel import redis -import tornado.ioloop -import tornado.web +from rq import Queue +from fastapi import FastAPI, Body +from .api.schemas import TaskComplete + try: from dask.distributed import Client @@ -14,109 +17,70 @@ import cs_storage -from cs_workers.services import auth -from cs_workers.utils import redis_conn_from_env +app = FastAPI() -redis_conn = dict( - username=os.environ.get("REDIS_USER"), - password=os.environ.get("REDIS_PW"), - **redis_conn_from_env(), +queue = Queue( + connection=redis.Redis( + host=os.environ.get("REDIS_HOST"), + port=os.environ.get("REDIS_PORT"), + password=os.environ.get("REDIS_PASSWORD"), + ) ) BUCKET = os.environ.get("BUCKET") -async def write(task_id, outputs): - async with await Client(asynchronous=True, processes=False) as client: - outputs = cs_storage.deserialize_from_json(outputs) - res = await client.submit(cs_storage.write, task_id, outputs) +class Result(BaseModel): + url: str + headers: dict + task: TaskComplete + + +def write(task_id, outputs): + outputs = cs_storage.deserialize_from_json(outputs) + res = cs_storage.write(task_id, outputs) return res -async def push(url, auth_headers, task_name, result): - async with httpx.AsyncClient(headers=auth_headers) as client: - if task_name == "sim": - print(f"posting data to {url}/outputs/api/") - return await client.put(f"{url}/outputs/api/", json=result) - if task_name == "parse": - print(f"posting data to {url}/inputs/api/") - return await client.put(f"{url}/inputs/api/", json=result) - else: - raise ValueError(f"Unknown task type: {task_name}.") - - -class Write(tornado.web.RequestHandler): - async def post(self): - print("POST -- /write/") - payload = json.loads(self.request.body.decode("utf-8")) - result = await write(**payload) - print("success-write") - self.write(result) - - -class Push(tornado.web.RequestHandler): - async def post(self): - print("POST -- /push/") - data = json.loads(self.request.body.decode("utf-8")) - job_id = data.get("result", {}).get("task_id", None) - if job_id is None: - print("missing job id") - self.set_status(400) - self.write(json.dumps({"error": "Missing job id."})) - return - - with redis.Redis(**redis_conn) as rclient: - data = rclient.get(f"jobinfo-{job_id}") - - if data is None: - print("Unknown job id: ", job_id) - self.set_status(400) - 
self.write(json.dumps({"error": "Unknown job id."}))
-            return
-
-        jobinfo = json.loads(data.decode())
-        print("got jobinfo", jobinfo)
-        cluster_user = jobinfo.get("cluster_user", None)
-        if cluster_user is None:
-            print("missing Cluster-User")
-            self.set_status(400)
-            self.write(json.dumps({"error": "Missing cluster_user."}))
-            return
-        user = auth.User.get(cluster_user)
-        if user is None:
-            print("unknown user", cluster_user)
-            self.set_status(404)
-            return
-
-        print("got user", user.username, user.url)
-
-        payload = json.loads(self.request.body.decode("utf-8"))
-        resp = await push(url=user.url, auth_headers=user.headers(), **payload)
-        print("got resp-push", resp.status_code, resp.url)
-        self.set_status(200)
-
-
-def get_app():
-    assert Client is not None, "Unable to import dask client"
-    assert auth.cryptkeeper is not None
-    assert BUCKET
-    return tornado.web.Application([(r"/write/", Write), (r"/push/", Push)])
-
-
-def start(args: argparse.Namespace):
-    if args.start:
-        app = get_app()
-        app.listen(8888)
-        tornado.ioloop.IOLoop.current().start()
-
-
-def cli(subparsers: argparse._SubParsersAction):
-    parser = subparsers.add_parser(
-        "outputs-processor",
-        aliases=["outputs"],
-        description="REST API for processing and storing outputs.",
-    )
-    parser.add_argument("--start", required=False, action="store_true")
-    parser.set_defaults(func=start)
+def push(job_id: str, result: Result):
+    resp = None
+    if result.task.task_name == "sim":
+        print(f"posting data to {result.url}/outputs/api/")
+        result.task.outputs = write(job_id, result.task.outputs)
+        resp = httpx.put(
+            f"{result.url}/outputs/api/",
+            json=dict(job_id=job_id, **result.task.dict()),
+            headers=result.headers,
+        )
+    elif result.task.task_name == "parse":
+        print(f"posting data to {result.url}/inputs/api/")
+        resp = httpx.put(
+            f"{result.url}/inputs/api/",
+            json=dict(job_id=job_id, **result.task.dict()),
+            headers=result.headers,
+        )
+    elif result.task.task_name == "defaults":
+        print(f"posting data to {result.url}/model-config/api/")
+        resp = httpx.put(
+            f"{result.url}/model-config/api/",
+            json=dict(job_id=job_id, **result.task.dict()),
+            headers=result.headers,
+        )
+
+    if resp is not None and resp.status_code == 400:
+        print(resp.text)
+        resp.raise_for_status()
+    elif resp is not None:
+        resp.raise_for_status()
+    else:
+        raise ValueError(
+            f"resp is None for: {job_id} with name {result.task.task_name}"
+        )
+
+
+@app.post("/{job_id}/", status_code=200)
+async def post(job_id: str, result: Result = Body(...)):
+    print("POST -- /", job_id)
+    queue.enqueue(push, job_id, result)
diff --git a/workers/cs_workers/services/rq_settings.py b/workers/cs_workers/services/rq_settings.py
new file mode 100644
index 00000000..ae9096bd
--- /dev/null
+++ b/workers/cs_workers/services/rq_settings.py
@@ -0,0 +1,10 @@
+import os
+
+host = os.environ.get("REDIS_HOST")
+port = os.environ.get("REDIS_PORT")
+password = os.environ.get("REDIS_PASSWORD", None)
+# Omit the auth segment when no password is configured so the URL stays valid.
+if password:
+    REDIS_URL = f"redis://:{password}@{host}:{port}/"
+else:
+    REDIS_URL = f"redis://{host}:{port}/"
diff --git a/workers/cs_workers/services/scheduler.py b/workers/cs_workers/services/scheduler.py
index 6dbf1854..f2a5a850 100644
--- a/workers/cs_workers/services/scheduler.py
+++ b/workers/cs_workers/services/scheduler.py
@@ -55,8 +55,11 @@ async def post(self, owner, title):
             return
 
         payload = Payload().loads(self.request.body.decode("utf-8"))
-        if f"{owner}/{title}" not in self.config[self.user.username].projects():
+        try:
+            project = self.config[self.user.username].get_project(owner, title)
+        except KeyError:
+            self.set_status(404)
+            return
task_id = payload.get("task_id") if task_id is None: @@ -88,7 +91,7 @@ async def post(self, owner, title): owner, title, tag=tag, - model_config=self.config[self.user.username], + model_config=project, job_id=task_id, job_kwargs=payload["task_kwargs"], rclient=self.rclient, @@ -146,7 +149,7 @@ def get(self, owner, title, deployment_name): owner=project["owner"], title=project["title"], tag=None, - model_config=self.config[self.user.username], + model_config=project, callable_name=project["callable_name"], deployment_name=deployment_name, incluster=incluster, @@ -174,7 +177,7 @@ def delete(self, owner, title, deployment_name): owner=project["owner"], title=project["title"], tag=None, - model_config=self.config[self.user.username], + model_config=project, callable_name=project["callable_name"], deployment_name=deployment_name, incluster=incluster, @@ -225,7 +228,7 @@ def post(self, owner, title): owner=project["owner"], title=project["title"], tag=data["tag"], - model_config=self.config[self.user.username], + model_config=project, callable_name=project["callable_name"], deployment_name=data["deployment_name"], incluster=incluster, diff --git a/workers/cs_workers/templates/services/outputs-processor-Deployment.template.yaml b/workers/cs_workers/templates/services/outputs-processor-Deployment.template.yaml deleted file mode 100755 index 8a5825e2..00000000 --- a/workers/cs_workers/templates/services/outputs-processor-Deployment.template.yaml +++ /dev/null @@ -1,56 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: outputs-processor -spec: - replicas: 1 - selector: - matchLabels: - app: outputs-processor - template: - metadata: - labels: - app: outputs-processor - spec: - serviceAccountName: outputs-processor - containers: - - name: outputs-processor - image: - ports: - - containerPort: 8888 - env: - - name: BUCKET - valueFrom: - secretKeyRef: - name: worker-secret - key: BUCKET - - name: REDIS_HOST - valueFrom: - secretKeyRef: - name: worker-secret - key: REDIS_HOST - - name: REDIS_PORT - valueFrom: - secretKeyRef: - name: worker-secret - key: REDIS_PORT - - name: REDIS_DB - valueFrom: - secretKeyRef: - name: worker-secret - key: REDIS_DB - optional: true - - name: REDIS_USER - value: outputs - - name: REDIS_PW - valueFrom: - secretKeyRef: - name: worker-secret - key: REDIS_OUTPUTS_PW - - name: CS_CRYPT_KEY - valueFrom: - secretKeyRef: - name: worker-secret - key: CS_CRYPT_KEY - nodeSelector: - component: api diff --git a/workers/cs_workers/templates/services/redis-master-Deployment.template.yaml b/workers/cs_workers/templates/services/redis-master-Deployment.template.yaml deleted file mode 100644 index daffb3bf..00000000 --- a/workers/cs_workers/templates/services/redis-master-Deployment.template.yaml +++ /dev/null @@ -1,48 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - labels: - app: redis - name: redis-master -spec: - replicas: 1 - selector: - matchLabels: - app: redis - role: master - tier: backend - template: - metadata: - labels: - app: redis - role: master - tier: backend - spec: - containers: - - env: [] - command: ["redis-server", "--appendonly", "yes"] - image: # redis-python - lifecycle: - postStart: - exec: - command: - - python3 - - /home/redis_init.py - name: master - ports: - - containerPort: 6379 - resources: - requests: - cpu: 100m - memory: 100Mi - volumeMounts: - - mountPath: /data - name: redis-volume - volumes: - - name: redis-volume - # This GCE PD must already exist. 
- gcePersistentDisk: - pdName: redis-disk - fsType: ext4 - nodeSelector: - component: api diff --git a/workers/cs_workers/templates/services/scheduler-Deployment.template.yaml b/workers/cs_workers/templates/services/scheduler-Deployment.template.yaml deleted file mode 100755 index a49c0500..00000000 --- a/workers/cs_workers/templates/services/scheduler-Deployment.template.yaml +++ /dev/null @@ -1,56 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: scheduler -spec: - replicas: 1 - selector: - matchLabels: - app: scheduler - template: - metadata: - labels: - app: scheduler - spec: - serviceAccountName: scheduler - containers: - - name: scheduler - image: - ports: - - containerPort: 8888 - env: - - name: PROJECT - valueFrom: - secretKeyRef: - name: worker-secret - key: PROJECT - - name: CS_CRYPT_KEY - valueFrom: - secretKeyRef: - name: worker-secret - key: CS_CRYPT_KEY - - name: REDIS_HOST - valueFrom: - secretKeyRef: - name: worker-secret - key: REDIS_HOST - - name: REDIS_PORT - valueFrom: - secretKeyRef: - name: worker-secret - key: REDIS_PORT - - name: REDIS_DB - valueFrom: - secretKeyRef: - name: worker-secret - key: REDIS_DB - optional: true - - name: REDIS_USER - value: scheduler - - name: REDIS_PW - valueFrom: - secretKeyRef: - name: worker-secret - key: REDIS_SCHEDULER_PW - nodeSelector: - component: api diff --git a/workers/cs_workers/templates/services/scheduler-Service.yaml b/workers/cs_workers/templates/services/scheduler-Service.yaml deleted file mode 100644 index 674835ae..00000000 --- a/workers/cs_workers/templates/services/scheduler-Service.yaml +++ /dev/null @@ -1,11 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: scheduler -spec: - ports: - - port: 80 - targetPort: 8888 - selector: - app: scheduler - type: LoadBalancer diff --git a/workers/cs_workers/templates/services/scheduler-ingressroute.template.yaml b/workers/cs_workers/templates/services/scheduler-ingressroute.template.yaml deleted file mode 100644 index e64f2714..00000000 --- a/workers/cs_workers/templates/services/scheduler-ingressroute.template.yaml +++ /dev/null @@ -1,16 +0,0 @@ -apiVersion: traefik.containo.us/v1alpha1 -kind: IngressRoute -metadata: - name: scheduler-tls - namespace: default -spec: - entryPoints: - - websecure - routes: - - match: - kind: Rule - services: - - name: scheduler - port: 80 - tls: - certResolver: myresolver diff --git a/workers/requirements.txt b/workers/requirements.txt index c490b75f..1dffa2ac 100755 --- a/workers/requirements.txt +++ b/workers/requirements.txt @@ -9,3 +9,6 @@ pyyaml google-cloud-secret-manager cs-crypt>=0.0.2 pyjwt +uvicorn[standard] +fastapi +pydantic[email,dotenv] \ No newline at end of file diff --git a/workers/setup.py b/workers/setup.py index 77ba1226..c7d23092 100644 --- a/workers/setup.py +++ b/workers/setup.py @@ -28,6 +28,9 @@ "tornado", "cs-storage>=1.11.0", "docker", + "pydantic[email,dotenv]", + "fastapi", + "rq", ], include_package_data=True, entry_points={