From 09d903236d452cb85417888fa32193b963f0fe34 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Fri, 28 Jun 2024 21:40:02 +0300 Subject: [PATCH] fix(fal_client): handle missing metrics (#253) * fix(fal_client): handle missing metrics * chore: enable gha tests * don't forget pillow * add FAL_KEY_PROD --- .github/workflows/tests-fal-client.yml | 28 ++++++++++++ projects/fal_client/pyproject.toml | 1 + projects/fal_client/src/fal_client/client.py | 4 +- projects/fal_client/tests/__init__.py | 0 projects/fal_client/tests/unit/__init__.py | 0 projects/fal_client/tests/unit/test_client.py | 43 +++++++++++++++++++ 6 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/tests-fal-client.yml create mode 100644 projects/fal_client/tests/__init__.py create mode 100644 projects/fal_client/tests/unit/__init__.py create mode 100644 projects/fal_client/tests/unit/test_client.py diff --git a/.github/workflows/tests-fal-client.yml b/.github/workflows/tests-fal-client.yml new file mode 100644 index 00000000..d7244333 --- /dev/null +++ b/.github/workflows/tests-fal-client.yml @@ -0,0 +1,28 @@ +name: Run fal-client tests + +on: + push: + branches: + - main + pull_request: + workflow_dispatch: + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: "3.12" + - name: Install dependencies + run: | + pip install --upgrade pip wheel + pip install -e 'projects/fal_client[test]' + - name: Run tests + env: + FAL_KEY: ${{ secrets.FAL_KEY_PROD }} + run: | + pytest projects/fal_client/tests \ No newline at end of file diff --git a/projects/fal_client/pyproject.toml b/projects/fal_client/pyproject.toml index c171a7cf..f447f92a 100644 --- a/projects/fal_client/pyproject.toml +++ b/projects/fal_client/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ test = [ "pytest", "pytest-asyncio", + "pillow", ] dev = [ "fal_client[test]", diff --git a/projects/fal_client/src/fal_client/client.py b/projects/fal_client/src/fal_client/client.py index ad467280..40d19cf2 100644 --- a/projects/fal_client/src/fal_client/client.py +++ b/projects/fal_client/src/fal_client/client.py @@ -72,7 +72,9 @@ def _parse_status(self, data: AnyJSON) -> Status: elif data["status"] == "IN_PROGRESS": return InProgress(logs=data["logs"]) elif data["status"] == "COMPLETED": - return Completed(logs=data["logs"], metrics=data["metrics"]) + # NOTE: legacy apps might not return metrics + metrics = data.get("metrics", {}) + return Completed(logs=data["logs"], metrics=metrics) else: raise ValueError(f"Unknown status: {data['status']}") diff --git a/projects/fal_client/tests/__init__.py b/projects/fal_client/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/projects/fal_client/tests/unit/__init__.py b/projects/fal_client/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/projects/fal_client/tests/unit/test_client.py b/projects/fal_client/tests/unit/test_client.py new file mode 100644 index 00000000..e42904de --- /dev/null +++ b/projects/fal_client/tests/unit/test_client.py @@ -0,0 +1,43 @@ +import pytest + +from fal_client.client import Queued, InProgress, Completed, _BaseRequestHandle + + +@pytest.mark.parametrize( + "data, result, raised", + [ + ( + {"status": "IN_QUEUE", "queue_position": 123}, + Queued(position=123), + False, + ), + ( + {"status": "IN_PROGRESS", "logs": [{"msg": "foo"}, {"msg": "bar"}]}, + InProgress(logs=[{"msg": "foo"}, {"msg": "bar"}]), + False, + ), + ( + {"status": "COMPLETED", "logs": [{"msg": "foo"}, {"msg": "bar"}]}, + Completed(logs=[{"msg": "foo"}, {"msg": "bar"}], metrics={}), + False, + ), + ( + {"status": "COMPLETED", "logs": [{"msg": "foo"}, {"msg": "bar"}], "metrics": {"m1": "v1", "m2": "v2"}}, + Completed(logs=[{"msg": "foo"}, {"msg": "bar"}], metrics={"m1": "v1", "m2": "v2"}), + False, + ), + ( + {"status": "FOO"}, + ValueError, + True, + ) + ] +) +def test_parse_status(data, result, raised): + handle = _BaseRequestHandle("foo", "bar", "baz", "qux") + + if raised: + with pytest.raises(result): + handle._parse_status(data) + else: + assert handle._parse_status(data) == result