From cc545ae2338882ace61104c946b6db682e2e92d1 Mon Sep 17 00:00:00 2001 From: technillogue Date: Fri, 30 Aug 2024 14:24:01 -0400 Subject: [PATCH] Revert "Handle predictors with deferred annotations (#1772)" This reverts commit 05900a7e6d8270df8175b1f7788f0b566486855c. This is needed to avoid breaking predictors that rely on __signature__ or partial. --- python/cog/predictor.py | 40 ++++++------------- .../future-annotations-project/predict.py | 8 ---- .../test_integration/test_predict.py | 30 -------------- 3 files changed, 13 insertions(+), 65 deletions(-) delete mode 100644 test-integration/test_integration/fixtures/future-annotations-project/predict.py diff --git a/python/cog/predictor.py b/python/cog/predictor.py index a2dd4d3ea0..50754f2ac7 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -18,7 +18,6 @@ Type, Union, cast, - get_type_hints, ) try: @@ -283,18 +282,13 @@ def validate_input_type( ) -def get_input_create_model_kwargs( - signature: inspect.Signature, input_types: Dict[str, Any] -) -> Dict[str, Any]: +def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any]: create_model_kwargs = {} order = 0 for name, parameter in signature.parameters.items(): - if name not in input_types: - raise TypeError(f"No input type provided for parameter `{name}`.") - - InputType = input_types[name] # pylint: disable=invalid-name + InputType = parameter.annotation validate_input_type(InputType, name) @@ -360,17 +354,13 @@ class Input(BaseModel): predict = get_predict(predictor) signature = inspect.signature(predict) - input_types = get_type_hints(predict) - if "return" in input_types: - del input_types["return"] - return create_model( "Input", __config__=None, __base__=BaseInput, __module__=__name__, __validators__=None, - **get_input_create_model_kwargs(signature, input_types), + **get_input_create_model_kwargs(signature), ) # type: ignore @@ -380,10 +370,9 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: """ predict = get_predict(predictor) - - input_types = get_type_hints(predict) - - if "return" not in input_types: + signature = inspect.signature(predict) + OutputType: Type[BaseModel] + if signature.return_annotation is inspect.Signature.empty: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -398,7 +387,8 @@ def predict( ... """ ) - OutputType = input_types.pop("return") # pylint: disable=invalid-name + else: + OutputType = signature.return_annotation # The type that goes in the response is a list of the yielded type if get_origin(OutputType) is Iterator: @@ -462,17 +452,13 @@ class TrainingInput(BaseModel): train = get_train(predictor) signature = inspect.signature(train) - input_types = get_type_hints(train) - if "return" in input_types: - del input_types["return"] - return create_model( "TrainingInput", __config__=None, __base__=BaseInput, __module__=__name__, __validators__=None, - **get_input_create_model_kwargs(signature, input_types), + **get_input_create_model_kwargs(signature), ) # type: ignore @@ -482,9 +468,9 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]: """ train = get_train(predictor) + signature = inspect.signature(train) - input_types = get_type_hints(train) - if "return" not in input_types: + if signature.return_annotation is inspect.Signature.empty: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -499,8 +485,8 @@ def train( ... """ ) - - TrainingOutputType = input_types.pop("return") # pylint: disable=invalid-name + else: + TrainingOutputType = signature.return_annotation name = ( TrainingOutputType.__name__ if hasattr(TrainingOutputType, "__name__") else "" diff --git a/test-integration/test_integration/fixtures/future-annotations-project/predict.py b/test-integration/test_integration/fixtures/future-annotations-project/predict.py deleted file mode 100644 index 791d2218fd..0000000000 --- a/test-integration/test_integration/fixtures/future-annotations-project/predict.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import annotations - -from cog import BasePredictor, Input - - -class Predictor(BasePredictor): - def predict(self, input: str = Input(description="Who to greet")) -> str: - return "hello " + input diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 34b81d3891..d6da057a96 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -288,33 +288,3 @@ def test_predict_path_list_input(tmpdir_factory): ) assert "test1" in result.stdout assert "test2" in result.stdout - - -def test_predict_works_with_deferred_annotations(): - project_dir = Path(__file__).parent / "fixtures/future-annotations-project" - - subprocess.check_call( - ["cog", "predict", "-i", "input=world"], - cwd=project_dir, - timeout=DEFAULT_TIMEOUT, - ) - - -def test_predict_int_none_output(): - project_dir = Path(__file__).parent / "fixtures/int-none-output-project" - - subprocess.check_call( - ["cog", "predict"], - cwd=project_dir, - timeout=DEFAULT_TIMEOUT, - ) - - -def test_predict_string_none_output(): - project_dir = Path(__file__).parent / "fixtures/string-none-output-project" - - subprocess.check_call( - ["cog", "predict"], - cwd=project_dir, - timeout=DEFAULT_TIMEOUT, - )