[Python3] Rewriting the code object of a function

Introduction

In Python it is relatively easy to overwrite a method or a function in a module, and Qiita already has several articles about it:

- Override library functions in Python
- Overwrite library processing with Python

These techniques only change what attribute access returns, as shown below. They do not touch the function object itself.

class SomeClass:
    def original_method(self):
        print('call original_method')

def new_method(self):
    print('call new_method')

some_instance1 = SomeClass()
some_instance2 = SomeClass()

# Rewrite the method on some_instance1 only
some_instance1.original_method = type(some_instance1.original_method)(new_method, some_instance1)
some_instance1.original_method()  # new_method() is called
some_instance2.original_method()  # original_method() is called

# Rewrite the method for all instances
SomeClass.original_method = new_method
some_instance1.original_method()  # new_method() is called
some_instance2.original_method()  # new_method() is called

import some_module

def new_func():
    print('call new_func')

# Rewrite the function in the module
some_module.original_func = new_func
some_module.original_func()  # new_func() is called

Most of the time these techniques are good enough. However, there are cases where changing attribute access does not replace the function: any reference taken before the overwrite still points at the original.

# If the attribute was retrieved before it was overwritten,
original_method = some_instance1.original_method

# then even if you overwrite the attribute afterwards,
type(some_instance1).original_method = new_method

# the reference taken earlier is not affected
original_method()  # the original original_method() is called

import some_module
from some_module import original_func

# The same goes for functions in a module
some_module.original_func = new_func
original_func()  # original_func() is called

What I want to do (usage)

I want the overwrite to take effect across the whole program, even for references that were taken before the overwrite.

import some_module
from some_module import original_func  # even if it was imported first

def new_func():
    print('call new_func')

overwrite_func(some_module.original_func, new_func)  # overwrite afterwards
original_func()  # new_func() should be called here
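
The reason this is possible at all is that every name bound to a function refers to the same function object, so a change to its __code__ is visible through all of them. The following is a minimal sketch for the simplest case, a function with no closure variables; the name alias is made up for this illustration.

```python
def original_func():
    return 'call original_func'

def new_func():
    return 'call new_func'

# Take a reference first, as "from some_module import original_func" would
alias = original_func

# Swap the code object in place instead of rebinding the name
original_func.__code__ = new_func.__code__

# The earlier reference now runs the new code
result = alias()
print(result)
```

This only works directly when both functions use the same number of closure variables, which is why the implementation needs more than a single assignment.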

Implementation (prototype)

def overwrite_func(orig, new):
    from uuid import uuid4
    # Random keyword name so it cannot clash with a real keyword argument
    kw = 'kw' + str(uuid4()).replace('-', '')
    freevars = list(orig.__code__.co_freevars)
    # Namespace for exec; keeps outer() out of the module globals
    namespace = {'new': new}
    # inner must reference the same number of free variables as orig
    exec("def outer():\n " + '='.join(freevars + ['None'])
         + "\n def inner(*args, " + kw + "=new, **kwargs):\n  "
         + ','.join(freevars)
         + "\n  return " + kw + "(*args, **kwargs)\n return inner",
         namespace)
    inner = namespace['outer']()
    orig.__code__ = inner.__code__
    orig.__defaults__ = inner.__defaults__
    orig.__kwdefaults__ = inner.__kwdefaults__

At first I thought it would be enough to overwrite __code__. However, the assignment fails unless the new code object has the same number of entries in __code__.co_freevars as the original (as far as I can tell, these are the variables of an enclosing function that a nested function uses). So I generate a function with a matching number of free variables using exec.
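
As a small illustration of that restriction, the sketch below tries to assign a code object with no free variables to a closure that has one. The functions make_counter and plain are made up for this example.

```python
def make_counter():
    count = 0
    def counter():
        nonlocal count
        count += 1
        return count
    return counter

def plain():
    return 'plain'

counter = make_counter()
print(counter.__code__.co_freevars)  # names of the closure variables
print(plain.__code__.co_freevars)    # empty for a plain function

# The closure has one cell, so a code object with zero free variables
# cannot be assigned to it
try:
    counter.__code__ = plain.__code__
except ValueError as exc:
    error = str(exc)
else:
    error = None
print(error)
```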

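Implementation (completed version)

The prototype loses the original signature, so the completed version preserves it as far as possible. The one difference is an extra keyword argument, __overwrite_func, which carries the replacement function as its default value so that __code__.co_freevars does not have to change.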

def overwrite_func(orig, new, signature=None):
    import inspect
    from types import FunctionType
    from textwrap import dedent
    assert isinstance(orig, FunctionType), (orig, type(orig))
    assert isinstance(new, FunctionType), (new, type(new))
    if signature is None:
        signature = inspect.signature(orig)
    params = [
        (str(p).split(':')[0].split('=')[0], p)
        for k, p in signature.parameters.items()
        if k != '__overwrite_func'
    ]
    default = {p.name: p.default for _, p in params}
    anno = {p.name: p.annotation for _, p in params}
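    # Pass positional parameters positionally; writing a=a before *args would
    # raise "multiple values for argument" when extra positionals are given
    args_kwargs = [
        k if k[0] == '*' or p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        else k + '=' + k
        for k, p in params
    ]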
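    # default and anno are keyed by p.name, which has no leading '*'
    # (k can be '*args' or '**kwargs')
    signature_ = [
        (k + (':anno["' + p.name + '"]' if p.annotation is not p.empty else '')
         + ('=default["' + p.name + '"]' if p.default is not p.empty else ''),
         not (p.kind == p.VAR_KEYWORD or p.kind == p.KEYWORD_ONLY))
        for k, p in params
    ]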
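    signature__ = [s for s, positional in signature_ if positional]
    # Without *args, a bare '*' is required so that __overwrite_func and any
    # keyword-only parameters after it form a valid parameter list
    if not any(p.kind == p.VAR_POSITIONAL for _, p in params):
        signature__.append('*')
    signature__.append('__overwrite_func=new')
    signature__.extend(s for s, positional in signature_ if not positional)
    signature__ = '(' + ', '.join(signature__) + ')'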
    if signature.return_annotation is not inspect.Signature.empty:
        anno['return'] = signature.return_annotation
        signature__ += ' -> anno["return"]'
    source = dedent("""
    def outer():
        """ + '='.join(list(orig.__code__.co_freevars) + ['None']) + """
        def inner""" + signature__ + """:
            """ + ', '.join(orig.__code__.co_freevars) + """
            return __overwrite_func(""" + ', '.join(args_kwargs) + """)
        return inner
    """)
    globals_ = {}
    exec(source, dict(new=new, default=default, anno=anno), globals_)
    inner = globals_['outer']()
    globals_.clear()
    orig.__code__ = inner.__code__
    orig.__defaults__ = inner.__defaults__
    orig.__kwdefaults__ = inner.__kwdefaults__
    orig.__annotations__ = inner.__annotations__
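
For comparison, here is a minimal sketch of what happens to the signature when the code object and defaults of a generic (*args, **kwargs) function are copied over, which is roughly what the prototype does. The functions orig and wrapper are placeholders for this illustration.

```python
import inspect

def orig(a, b=1):
    return a + b

def wrapper(*args, **kwargs):
    return None

before = str(inspect.signature(orig))

# Copy the code object and defaults of a generic wrapper
orig.__code__ = wrapper.__code__
orig.__defaults__ = wrapper.__defaults__
orig.__kwdefaults__ = wrapper.__kwdefaults__

after = str(inspect.signature(orig))
print(before)
print(after)
```

Because inspect.signature reads the parameter list from the code object, the original parameter names disappear unless the generated function declares them again.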

Caution

This function is not universal. It cannot handle built-in functions that have no __code__, or callable objects that implement __call__. Choose the technique that suits the target.

overwrite_func(print, new_func)  # when assertions are disabled (python -O)
# → AttributeError: 'builtin_function_or_method' object has no attribute '__code__'
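
One way to check a target up front is to test for types.FunctionType, which is the same test the assert in overwrite_func performs. A small sketch; has_rewritable_code and CallableObject are hypothetical names used only here.

```python
import types

class CallableObject:
    def __call__(self):
        return 'called'

def plain_function():
    return 'called'

def has_rewritable_code(obj):
    """Hypothetical helper: True only for pure-Python functions."""
    return isinstance(obj, types.FunctionType)

checks = {
    'plain_function': has_rewritable_code(plain_function),
    'print': has_rewritable_code(print),
    'CallableObject()': has_rewritable_code(CallableObject()),
}
print(checks)
```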

Beware of memory leaks: the replacement function is held as the default value of __overwrite_func, so references to replacement functions can accumulate there.
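
The sketch below illustrates the underlying behaviour with a stand-in function rather than overwrite_func itself: an object stored as a keyword default stays alive for as long as the function keeps that default. The names target and _replacement are assumptions for this example.

```python
import gc
import weakref

def target(*args, _replacement=None, **kwargs):
    # Stand-in for a function that has been rewritten by overwrite_func
    return _replacement(*args, **kwargs)

def new_func():
    return 'call new_func'

tracker = weakref.ref(new_func)
target.__kwdefaults__ = {'_replacement': new_func}

# Deleting the name does not free the function: the default still holds it
del new_func
gc.collect()
alive_while_referenced = tracker() is not None

# Dropping the default releases the last reference
target.__kwdefaults__ = {'_replacement': None}
gc.collect()
alive_after_release = tracker() is not None
print(alive_while_referenced, alive_after_release)
```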

Application example

def copy_func(f):
    """https://stackoverflow.com/questions/13503079"""
    import functools
    import types
    assert isinstance(f, types.FunctionType), (f, type(f))
    g = types.FunctionType(
        f.__code__,
        f.__globals__, 
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    g.__kwdefaults__ = f.__kwdefaults__
    functools.update_wrapper(g, f)
    return g

def add_hook(func, pre_call=None, post_call=None, except_=None, finally_=None):
    import inspect
    func_sig = inspect.signature(func)
    func_copy = copy_func(func)

    def hook(*args, **kwargs):
        bound_args = func_sig.bind(*args, **kwargs)
        if pre_call is not None:
            pre_call(func_copy, bound_args)
        try:
            return_ = func_copy(*args, **kwargs)
        except BaseException as e:
            if except_ is not None:
                except_(func_copy, bound_args, e)
            raise
        else:
            if post_call is not None:
                post_call(func_copy, bound_args, return_)
        finally:
            if finally_ is not None:
                finally_(func_copy, bound_args)
        return return_

    overwrite_func(func, hook)

With add_hook you can attach callbacks to an existing function after the fact.

def callback(f, args, result):
    print(result)

add_hook(original_func, post_call=callback)
original_func()  # original_func() runs, then callback() is called
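
add_hook copies the function before overwriting it because overwrite_func changes the function object in place. If the hook called the function itself, it would end up calling the hook again. The following sketch shows that a copy built with types.FunctionType keeps the old code object; greet, replacement and snapshot are names made up for this illustration.

```python
import types

def greet():
    return 'hello'

def replacement():
    return 'replaced'

# Copy the function object, as copy_func does
snapshot = types.FunctionType(
    greet.__code__,
    greet.__globals__,
    name=greet.__name__,
    argdefs=greet.__defaults__,
    closure=greet.__closure__,
)

# Rewriting the original does not affect the copy
greet.__code__ = replacement.__code__

results = (greet(), snapshot())
print(results)
```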

Conclusion

I got it working, but this is something you should not do unless you really have to. I do not fully understand the specification around __code__, so my test cases are probably insufficient. Please let me know if you find a case that does not work.
