diff --git a/python/cog/predictor.py b/python/cog/predictor.py index a2dd4d3ea0..45056e8be8 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -291,10 +291,7 @@ def get_input_create_model_kwargs( order = 0 for name, parameter in signature.parameters.items(): - if name not in input_types: - raise TypeError(f"No input type provided for parameter `{name}`.") - - InputType = input_types[name] # pylint: disable=invalid-name + InputType = input_types.get(name, parameter.annotation) validate_input_type(InputType, name) @@ -360,7 +357,10 @@ class Input(BaseModel): predict = get_predict(predictor) signature = inspect.signature(predict) - input_types = get_type_hints(predict) + try: + input_types = get_type_hints(predict) + except TypeError: + input_types = {} if "return" in input_types: del input_types["return"] @@ -374,16 +374,25 @@ class Input(BaseModel): ) # type: ignore +def get_return_annotation(fn: Callable[..., Any]) -> Optional[Type[Any]]: + try: + return get_type_hints(fn).get("return", None) + except TypeError: + return_annotation = inspect.signature(fn).return_annotation + if return_annotation is inspect.Signature.empty: + return None + return return_annotation + + def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: """ Creates a Pydantic Output model from the return type annotation of a Predictor's predict() method. """ predict = get_predict(predictor) + maybe_output_type = get_return_annotation(predict) - input_types = get_type_hints(predict) - - if "return" not in input_types: + if maybe_output_type is None: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -398,8 +407,8 @@ def predict( ... """ ) - OutputType = input_types.pop("return") # pylint: disable=invalid-name - + # we need the indirection to narrow the type to the one that's declared later + OutputType = maybe_output_type # pylint: disable=invalid-name # The type that goes in the response is a list of the yielded type if get_origin(OutputType) is Iterator: # Annotated allows us to attach Field annotations to the list, which we use to mark that this is an iterator @@ -462,7 +471,10 @@ class TrainingInput(BaseModel): train = get_train(predictor) signature = inspect.signature(train) - input_types = get_type_hints(train) + try: + input_types = get_type_hints(train) + except TypeError: + input_types = {} if "return" in input_types: del input_types["return"] @@ -483,8 +495,8 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]: train = get_train(predictor) - input_types = get_type_hints(train) - if "return" not in input_types: + TrainingOutputType = get_return_annotation(train) # pylint: disable=invalid-name + if not TrainingOutputType: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -500,8 +512,6 @@ def train( """ ) - TrainingOutputType = input_types.pop("return") # pylint: disable=invalid-name - name = ( TrainingOutputType.__name__ if hasattr(TrainingOutputType, "__name__") else "" ) diff --git a/test-integration/test_integration/fixtures/partial-predict-project/predict.py b/test-integration/test_integration/fixtures/partial-predict-project/predict.py index 0e24def50c..e604c990b3 100644 --- a/test-integration/test_integration/fixtures/partial-predict-project/predict.py +++ b/test-integration/test_integration/fixtures/partial-predict-project/predict.py @@ -1,5 +1,6 @@ -from typing import Callable import functools +import inspect +from typing import Any, Callable from cog import BasePredictor, Input @@ -10,7 +11,7 @@ def general( ) -> int: return 1 - def _remove(f: Callable, defaults: dict[str, Any]) -> Callable: + def _remove(f: Callable, defaults: "dict[str, Any]") -> Callable: # pylint: disable=no-self-argument def wrapper(self, *args, **kwargs): kwargs.update(defaults) @@ -31,7 +32,7 @@ def wrapper(self, *args, **kwargs): predict = _remove(general, {"system_prompt": ""}) -def _train(self, prompt: str = Input(description="hi"), system_prompt: str = None): +def _train(prompt: str = Input(description="hi"), system_prompt: str = None): return 1