Wrapper to validate arguments and serialize return values using Pydantic.
(
app: "Celery",
task_fun: typing.Callable[..., typing.Any],
task_name: str,
strict: bool = True,
context: typing.Optional[typing.Dict[str, typing.Any]] = None,
dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None
)
| 121 | |
| 122 | |
| 123 | def pydantic_wrapper( |
| 124 | app: "Celery", |
| 125 | task_fun: typing.Callable[..., typing.Any], |
| 126 | task_name: str, |
| 127 | strict: bool = True, |
| 128 | context: typing.Optional[typing.Dict[str, typing.Any]] = None, |
| 129 | dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None |
| 130 | ): |
| 131 | """Wrapper to validate arguments and serialize return values using Pydantic.""" |
| 132 | try: |
| 133 | pydantic = importlib.import_module('pydantic') |
| 134 | except ModuleNotFoundError as ex: |
| 135 | raise ImproperlyConfigured('You need to install pydantic to use pydantic model serialization.') from ex |
| 136 | |
| 137 | BaseModel: typing.Type['BaseModel'] = pydantic.BaseModel # noqa: F811 # only defined when type checking |
| 138 | |
| 139 | if context is None: |
| 140 | context = {} |
| 141 | if dump_kwargs is None: |
| 142 | dump_kwargs = {} |
| 143 | dump_kwargs.setdefault('mode', 'json') |
| 144 | |
| 145 | # If a file uses `from __future__ import annotations`, all annotations will |
| 146 | # be strings. `typing.get_type_hints()` can turn these back into real |
| 147 | # types, but can also sometimes fail due to circular imports. Try that |
| 148 | # first, and fall back to annotations from `inspect.signature()`. |
| 149 | task_signature = inspect.signature(task_fun) |
| 150 | |
| 151 | try: |
| 152 | type_hints = typing.get_type_hints(task_fun) |
| 153 | except (NameError, AttributeError, TypeError): |
| 154 | # Fall back to raw annotations from inspect if get_type_hints fails |
| 155 | type_hints = None |
| 156 | |
| 157 | @functools.wraps(task_fun) |
| 158 | def wrapper(*task_args, **task_kwargs): |
| 159 | # Validate task parameters if type hinted as BaseModel |
| 160 | bound_args = task_signature.bind(*task_args, **task_kwargs) |
| 161 | for arg_name, arg_value in bound_args.arguments.items(): |
| 162 | if type_hints and arg_name in type_hints: |
| 163 | arg_annotation = type_hints[arg_name] |
| 164 | else: |
| 165 | arg_annotation = task_signature.parameters[arg_name].annotation |
| 166 | |
| 167 | optional_arg = get_optional_arg(arg_annotation) |
| 168 | if optional_arg is not None and arg_value is not None: |
| 169 | arg_annotation = optional_arg |
| 170 | |
| 171 | if annotation_issubclass(arg_annotation, BaseModel): |
| 172 | bound_args.arguments[arg_name] = arg_annotation.model_validate( |
| 173 | arg_value, |
| 174 | strict=strict, |
| 175 | context={**context, 'celery_app': app, 'celery_task_name': task_name}, |
| 176 | ) |
| 177 | |
| 178 | # Call the task with (potentially) converted arguments |
| 179 | returned_value = task_fun(*bound_args.args, **bound_args.kwargs) |
| 180 |
no test coverage detected