Check argument type annotations when a function is called in Python and raise an error on a mismatch

Introduction

As every Python user knows, Python does not enforce types. Until type hints (the typing module) arrived in Python 3.5, there was not even a standard way to write a type down; you had to guess from comments and variable names. (Personally, I found dictionaries the worst.)

So type annotations have been a real help. Writing them is already a habit, and leaving them out feels strange. (I suspect many people feel the same, lol.)

However, annotations are still not enforced; they are only checked by tools such as mypy or Pylance, the VS Code extension. If you ship a module to a third party, you cannot check ahead of time the types that callers pass across the interface (IF), so the implementation has to guard against bad input itself. Writing that guard (check processing) in every function is tedious and is not the essence of the code. My motivation is to write it concisely.
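
To make the "not enforced" point concrete, here is a minimal sketch (the function name `double` is made up for illustration) showing that the interpreter ignores annotations at call time:

```python
def double(value: int) -> int:
    """Annotated as int -> int, but nothing enforces it at runtime."""
    return value * 2


# The annotation says int, yet a str is accepted without any error.
result = double("ab")
print(result)  # abab

# The annotations are only stored as metadata on the function object.
print(double.__annotations__)
```

A static checker such as mypy would flag the call, but the program itself runs to completion.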

Approach

Python has a handy feature called decorators. A decorator lets you run processing before a function executes, so the approach is to perform the check there.

Each function then only needs the decorator written above it.
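
As a minimal sketch of that idea (the names `require_positive` and `area` are hypothetical and are not part of the implementation in this article), a decorator can run a check before handing control to the wrapped function:

```python
import functools


def require_positive(func):
    """Hypothetical decorator: reject non-positive arguments before calling func."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Pre-processing: this runs before the decorated function body
        for value in list(args) + list(kwargs.values()):
            if value <= 0:
                raise ValueError(f"{func.__name__} expects positive values, got {value!r}")
        return func(*args, **kwargs)
    return wrapper


@require_positive
def area(width, height):
    return width * height


print(area(2, 3))  # 6
try:
    area(2, -1)
except ValueError as exc:
    print(exc)
```

The type-checking decorator follows the same shape; only the check inside the wrapper differs.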

Implementation

Define the following decorator function. The error argument specifies the exception class to raise when an argument does not match its type annotation. check_all_collection specifies whether every element is checked when the argument is a collection type; when it is False, only the first element is checked.

"""
Decorator definition file to check argument types
"""
import functools
import inspect
from typing import Any, Callable, Type, Union, _GenericAlias


def check_args_type(error: Type[Exception] = TypeError, check_all_collection: bool = False):
    """
    Decorator that checks whether each argument's type matches its annotation
    Args:
        error: Exception class to raise on a mismatch
        check_all_collection: Whether to check every element of a collection argument
    """
    def _decorator(func: Callable):
        @functools.wraps(func)
        def args_type_check_wrapper(*args, **kwargs):
            sig = inspect.signature(func)
            # Bind first, so that only a wrong number or name of arguments is
            # reported here and the detailed mismatch message below is not swallowed
            try:
                bound = sig.bind(*args, **kwargs)
            except TypeError as exc:
                raise error("The number or names of the arguments do not match.") from exc

            for arg_key, arg_val in bound.arguments.items():
                # Skip parameters whose annotation is missing or is not a type
                annotation = sig.parameters[arg_key].annotation
                if annotation is inspect.Parameter.empty:
                    continue
                if not isinstance(annotation, (type, _GenericAlias)):
                    continue

                # Check for a match.
                # For a generic type, the origin and its type arguments are compared.
                is_match = __check_generic_alias(annotation, arg_val, check_all_collection)
                if not is_match:
                    message = (
                        f"The type of argument '{arg_key}' is incorrect. "
                        f"annotation: {annotation}, request: {type(arg_val)}"
                    )
                    raise error(message)
            return func(*args, **kwargs)
        return args_type_check_wrapper
    return _decorator

def __check_generic_alias(
    annotation: Union[_GenericAlias, type],
    request: Any,
    check_all_collection: bool = False
):
    """
    Type check that also handles generic aliases
    Args:
        annotation: Annotation type
        request: Value passed as the argument
        check_all_collection: Whether to check every element of a collection argument
    """
    # Any accepts every value, so skip the check
    if annotation == Any:
        return True

    # Type check
    request_type = type(request)
    if isinstance(annotation, _GenericAlias):
        if annotation.__origin__ == request_type:    # for collection ...list, dict, set
            # -----------
            # list
            # -----------
            if annotation.__origin__ == list and request:
                _annotation = annotation.__args__[0]
                if check_all_collection:    # Check every element
                    return all(
                        __check_generic_alias(_annotation, _request, check_all_collection)
                        for _request in request
                    )
                # Otherwise check only the first element
                return __check_generic_alias(
                    _annotation, request[0], check_all_collection
                )

            # -----------
            # dict
            # -----------
            if annotation.__origin__ == dict and request:
                _annotation_key = annotation.__args__[0]
                _annotation_value = annotation.__args__[1]
                if check_all_collection:    # Check every key and every value
                    return all(
                        __check_generic_alias(_annotation_key, _key, check_all_collection)
                        and __check_generic_alias(_annotation_value, _value, check_all_collection)
                        for _key, _value in request.items()
                    )
                # Otherwise check only the first key/value pair
                _key, _value = next(iter(request.items()))
                return (
                    __check_generic_alias(_annotation_key, _key, check_all_collection)
                    and __check_generic_alias(_annotation_value, _value, check_all_collection)
                )

            # Empty collections and other matching origins (set, tuple, ...):
            # a matching origin is enough, the elements are not checked
            return True

        else:
            # For list/dict, a different origin is a mismatch
            origin = annotation.__origin__
            if origin == list or origin == dict:
                return False
            # Otherwise (e.g. Union), check each type argument recursively
            else:
                for arg in annotation.__args__:
                    is_match = __check_generic_alias(arg, request, check_all_collection)
                    if is_match:
                        return True
    else:
        # bool is a subclass of int, so issubclass would return True.
        # Reject it anyway, because the two mean different things.
        if request_type == bool and annotation == int:
            return False
        return issubclass(request_type, annotation)
    return False
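
The implementation above leans on a few standard-library behaviours, sketched here in isolation. `inspect.signature(...).bind(...)` maps positional and keyword arguments to parameter names, and a generic annotation such as `List[int]` exposes its container type and its element types. This sketch uses the public helpers `typing.get_origin` and `typing.get_args` (available from Python 3.8) instead of the `__origin__` and `__args__` attributes, and the function `sample` is made up for illustration.

```python
import inspect
from typing import Dict, List, Union, get_args, get_origin


def sample(value: List[int], name: str = "x"):
    """Hypothetical function used only to inspect its signature."""


sig = inspect.signature(sample)

# bind() pairs every argument with its parameter name, positional or keyword
bound = sig.bind([1, 2], name="y")
print(dict(bound.arguments))  # {'value': [1, 2], 'name': 'y'}

# A generic annotation splits into a container type and its type arguments
annotation = sig.parameters["value"].annotation
print(get_origin(annotation), get_args(annotation))  # <class 'list'> (<class 'int'>,)
print(get_args(Dict[str, int]))   # (<class 'str'>, <class 'int'>)
print(get_args(Union[int, str]))  # (<class 'int'>, <class 'str'>)

# bind() raises TypeError when the arguments do not fit the signature
try:
    sig.bind()
except TypeError as exc:
    print(exc)

# bool is a subclass of int, which is why the bool/int case needs special handling
print(issubclass(bool, int))  # True
```
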

Usage example #1.

# The simplest pattern
@check_args_type()
def test(value: int, is_valid: bool) -> float:
    """
    (omitted)
    """
    return 0.0

def main():
    # OK
    result = test(5, True)

    # NG -> TypeError
    result = test(0.0, False)

    # NG2 -> TypeError
    result = test(1, "True")

Usage example #2.

from typing import List

# A pattern that checks every element of the collection
@check_args_type(check_all_collection=True)
def test2(value: List[int]) -> List[float]:
    """
    (omitted)
    """
    return [0.0]

def main():
    # OK
    result = test2([0, 5, 10, 20])

    # NG -> TypeError
    result = test2([0.0, 5.0, 10.0, 20.0])

    # NG2 -> TypeError (caught only because every element is checked)
    result = test2([0, 5, "test"])

Some types are probably not handled well, such as Enums and generators, but I think the basic types are covered. (Please extend it as needed.)
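
One possible direction for such types, shown only as a sketch and not as part of the decorator above: when the annotation's origin is an abstract base class such as `collections.abc.Generator`, comparing with `isinstance` instead of requiring an exact type match lets generators pass. Enum classes are ordinary classes, so a plain `isinstance` check works for them as well. The helper name `origin_matches`, the `Color` enum, and `count_up` are assumptions made for illustration, and element types are deliberately not checked.

```python
from enum import Enum
from typing import Any, Generator, Iterable, List, get_origin


def origin_matches(annotation: Any, value: Any) -> bool:
    """Hypothetical helper: compare only the outer type, ignoring element types."""
    if annotation is Any:
        return True
    origin = get_origin(annotation)   # list for List[int], None for plain classes
    target = origin if origin is not None else annotation
    if not isinstance(target, type):
        return True                   # not a class, so isinstance cannot be applied: skip
    return isinstance(value, target)


class Color(Enum):
    RED = 1


def count_up(limit: int):
    for i in range(limit):
        yield i


print(origin_matches(Generator[int, None, None], count_up(3)))  # True
print(origin_matches(Iterable[int], [1, 2]))                    # True
print(origin_matches(List[int], (1, 2)))                        # False
print(origin_matches(Color, Color.RED))                         # True
print(origin_matches(Color, 1))                                 # False
```

Checking the values a generator yields would consume it, so a sketch like this can only look at the outer type.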

Summary

I introduced a way to check argument type annotations when a function is called and raise an error on a mismatch. With this, types can actually be enforced. I think it is useful where strictness is required (at boundaries such as interfaces).

PS) It would be great to also check values, as in contract programming, so I plan to extend it.
