开发启发式来测试简单的匿名Python函数的等价性 [英] Developing a heuristic to test simple anonymous Python functions for equivalency

查看:151
本文介绍了开发启发式来测试简单的匿名Python函数的等价性的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我知道Python 3中的函数比较是如何工作的(只是比较内存中的地址),并且我明白为什么。我也明白true比较函数 f g 返回给定相同参数的相同结果,对于任何参数?)实际上是不可能的。 p>

我正在寻找中间的东西。我想比较工作在相同功能的最简单情况下,可能还有一些不太重要的情况:

  lambda x:x = = lambda x:x#True 
lambda x:2 * x == lambda y:2 * y#True
lambda x:2 * x == lambda x:x * 2#True或False是很好,但必须稳定
lambda x:2 * x == lambda x:x + x#True or False is fine,but must be stable

请注意,我对解决匿名函数( lambda )的这个问题感兴趣,但是不介意if该解决方案也适用于命名函数。



这样做的动机是在 blist 模块中,这将是很好的验证两个 sortedset 实例在对它们执行联合等操作之前具有相同的排序功能。

命名函数的兴趣不大,因为当它们不相同时,我可以假定它们是不同的。毕竟,假设某人在参数中创建了具有命名函数的两个排序集。如果他们打算让这些实例为设置操作的目的而兼容,那么他们可能会使用相同的函数,而不是两个单独的执行相同操作的命名函数。



<我只能想到三种方法。他们都看起来很难,所以任何想法都很感激。
$ b


  1. 比较字节码可能有效,但它可能会令人烦恼,因为它依赖于实现(因此,另外)。

  2. 比较标记化的源代码似乎是合理和可移植的。当然,它不那么强大(因为相同的函数更可能被拒绝)。
  3. 从一些符号计算教科书中借鉴的可靠启发理论上是最好的方法。对于我的目的来说,它可能看起来太重了,但它实际上可能是一个很好的选择,因为lambda函数通常很小,所以它运行得很快。


编辑

一个更复杂的例子,基于@delnan的评论:



<$

$全局字段$ b $全局变量
$全局变量
$ b s1 = sortedset(key = lambda x:x [fields [0] .lower()])
#这里有一些中间代码
#...
s2 = sortedset(key = lambda x:x [fields [0] .lower()])

s1 s2 评估为等于?

如果干预代码完全包含任何函数调用,字段的值可能会被修改,从而导致 s1 和 s2 。由于我们显然不会进行控制流分析来解决这个问题,很明显,如果我们试图在运行时执行此评估,我们必须将这两个lambda函数评估为不同的。 (即使字段不是全局的,它可能会有另一个名字绑定到它,等等)。这将严重缩减整个练习的有用性,因为很少lambda功能将不会依赖于环境。



编辑2:



我意识到这是非常重要的比较函数对象,因为它们存在于运行时。没有这些,所有依赖于外部变量的函数都不能被比较;而且大多数有用的函数都有这样的依赖关系。在运行时考虑,具有相同签名的所有函数都可以以干净,合乎逻辑的方式进行比较,无论它们依赖于什么,它们是否不纯等等。



作为结果,我不仅需要字节码,而且还需要创建函数对象时的全局状态(可能是 __ globals __ )。然后我必须将外部范围的所有变量匹配到 __ globals __ 中的值。

解决方案

编辑以检查外部状态是否会影响排序功能以及两个函数是否相同。




我破解了 dis.dis 和朋友输出到一个全局文件对象。然后,我删除了行号和标准化的变量名称(不接触常量)并比较结果。

您可以清理它,以便 dis.dis 和朋友产生结束行,所以你不必捕捉他们的输出。但是,对于使用 dis.dis 来进行函数比较以及最小的更改,这是一个工作证明概念。

 导入类型
来自操作码导入*
_have_code =(types.MethodType,types.FunctionType,types.CodeType,$ b $ types.ClassType,type)
$ b $ def dis(x):
反汇编类,方法,函数或代码

没有参数,反汇编最后的回溯
如果isinstance(x,types.InstanceType)为


x = x .__ class__ $ b $如果hasattr(x,'im_func'):
x = x。如果hasattr(x,'func_code'):
x = x.func_code
如果hasattr(x,'__dict__'):
items = x .__ dict __。items()
items.sort()
表示名称,x1表示项目:
如果isinstance(x1,_have_code):
print>> %s
尝试:
dis(x1)
除TypeError,msg:
print>>外, out,Sorry:,msg
print>> out
elif hasattr(x,'co_code'):
反汇编(x)
elif isinstance(x,str):
disassemble_string(x)
else:
raise TypeError,\
不知道如何反汇编%s对象%\
type(x).__ name__
$ b $ def defassemble(co, lasti = -1):
反汇编代码对象。
code = co.co_code
labels = findlabels(code)
linestarts = dict(findlinestarts( co))
n = len(code)
i = 0
extended_arg = 0
free = None
while i< n:
c = code [i]
op = ord(c)
如果我在linestarts中:
if i> 0:
print>> out
print>> out,%3d%linestarts [i],
else:
print>> out,'',

if i == lasti:print>> out,' - >',
else:print>> out,'',
如果我在标签中:print>> out,'>>',
else:print>> out,'',
print>>退出,repr(i).rjust(4),
print>>如果op> = HAVE_ARGUMENT:
oparg = ord(code [i])+ ord(code [i]),opname [op] .ljust(20),
i = i + 1
+1])* 256 + extended_arg
extended_arg = 0
i = i + 2
if op == EXTENDED_ARG:
extended_arg = oparg * 65536L
print>> ; out,repr(oparg).rjust(5),
if op in hasconst:
print>> out','('+ repr(co.co_consts [oparg])+')',
elif在hasname中:
print>> out','('+ co.co_names [oparg] +')',
elif op in hasjrel:
print>> out',('+ repr(i + oparg)+')',
elif在haslocal中:
print>> out','('+ co.co_varnames [oparg] +')',
elif op in hascompare:
print>> '('+ cmp_op [oparg] +')',
elf在hasfree中:
如果是free:
free = co.co_cellvars + co.co_freevars
打印>> out','('+ free [oparg] +')',
print>> out
$ b $ def defassemble_string(code,lasti = -1,varnames = None,names = None,
constants = None):
labels = findlabels(code)
n = len(代码)
i = 0
while i< n:
c = code [i]
op = ord(c)
if i == lasti:print>> out,' - >',
else:print>> out,'',
如果我在标签中:print>> out,'>>',
else:print>> out,'',
print>>退出,repr(i).rjust(4),
print>>如果op> = HAVE_ARGUMENT:
oparg = ord(code [i])+ ord(code [i]),opname [op] .ljust(15),
i = i + 1
+1])* 256
i = i + 2
print>> out,repr(oparg).rjust(5),
if hascon hasconst:
if常量:
print>> out','('+ repr(constants [oparg])+')',
else:
print>> out,'(%d)'%oparg,
elif在hasname中:
如果名称不是None:
print>> out','('+ names [oparg] +')',
else:
print>> '(%d)'%oparg,
elif op in hasjrel:
print>> out',('+ repr(i + oparg)+')',
elif op in haslocal:
if varnames:
print>> out','('+ varnames [oparg] +')',
else:
print>> out,'(%d)'%oparg,
elif op in hascompare:
print>> out','('+ cmp_op [oparg] +')',
print>> out

def findlabels(code):
检测跳转目标的字节码中的所有偏移量

返回偏移量列表


labels = []
n = len(code)
i = 0
while i< n:
c = code [i]
op = ord(c)
i = i + 1
if op> = HAVE_ARGUMENT:
oparg = ord(code [ i])+ ord(code [i + 1])* 256
i = i + 2
label = -1
如果在hasjrel中操作:
label = i + oparg
elif op in hasjabs:
label = oparg
如果标签> = 0:
如果标签不在标签中:
labels.append(label)
返回标签

def findlinestarts(code):
在字节代码中查找偏移量,这些偏移量是源代码行中的行数

生成对如在Python / compile.c中所述。


byte_increments = [ord(c)for code.co_lnotab [0 :: 2]]
line_increments = [ord(c)for code.co_lnotab [1 :: 2]]

lastlineno = None
lineno = code.co_firstlineno
addr = 0
表示byte_incr,line_incr表示zi p(byte_increments,line_increments):
if byte_incr:
if lineno!= lastlineno:
yield(addr,lineno)
lastlineno = lineno
addr + = byte_incr
lineno + = line_incr
如果lineno!= lastlineno:
yield(addr,lineno)
$ b $ class FakeFile(object):
def __init __(self) :
self.store = []
def write(self,data):
self.store.append(data)
$ b $ = lambda x:x
b = lambda x:x#True
c = lambda x:2 * x
d = lambda y:2 * y#True
e = lambda x:2 * x
f = lambda x:x * 2#True或False很好,但必须稳定
g = lambda x:2 * x
h = lambda x:x + x#True或False很好,但必须稳定

funcs = a,b,c,d,e,f,g,h

出局= []
func in funcs:
out = FakeFile()
dis(func)
outs.append(out.store)

import ast

def outf ilter(out):
for i out:
if i.strip()。isdigit():
continue $ b $ if if('in':
try :
ast.literal_eval(i)
除了ValueError:
i =(x)
收益率

processed_outs = [(out,'LOAD_GLOBAL 'out'或'LOAD_DECREF'out)
出(in''.join(outfilter(out))out out out)]

for(out1,polluted1),( out2,polluted2)in zip(processed_outs [:: 2],processed_outs [1 :: 2]):
print'Bytecode Equivalent:',out1 == out2,'\\\
Polluted by state:',polluted1 or污染2

输出为 True True False False 并且是稳定的。如果输出取决于外部状态 - 全局状态或闭包,污染布尔值为真。


I know how function comparison works in Python 3 (just comparing address in memory), and I understand why.

I also understand that "true" comparison (do functions f and g return the same result given the same arguments, for any arguments?) is practically impossible.

I am looking for something in between. I want the comparison to work on the simplest cases of identical functions, and possibly some less trivial ones:

lambda x : x == lambda x : x # True
lambda x : 2 * x == lambda y : 2 * y # True
lambda x : 2 * x == lambda x : x * 2 # True or False is fine, but must be stable
lambda x : 2 * x == lambda x : x + x # True or False is fine, but must be stable

Note that I'm interested in solving this problem for anonymous functions (lambda), but wouldn't mind if the solution also works for named functions.

The motivation for this is that inside blist module, it would be nice to verify that two sortedset instances have the same sort function before performing a union, etc. on them.

Named functions are of less interest because I can assume them to be different when they are not identical. After all, suppose someone created two sortedsets with a named function in the key argument. If they intend these instances to be "compatible" for the purposes of set operations, they'd probably use the same function, rather than two separate named functions that perform identical operations.

I can only think of three approaches. All of them seem hard, so any ideas appreciated.

  1. Comparing bytecodes might work but it might be annoying that it's implementation dependent (and hence the code that worked on one Python breaks on another).

  2. Comparing tokenized source code seems reasonable and portable. Of course, it's less powerful (since identical functions are more likely to be rejected).

  3. A solid heuristic borrowed from some symbolic computation textbook is theoretically the best approach. It might seem too heavy for my purpose, but it actually could be a good fit since lambda functions are usually tiny and so it would run fast.

EDIT

A more complicated example, based on the comment by @delnan:

# global variable
fields = ['id', 'name']

def my_function():
  global fields
  s1 = sortedset(key = lambda x : x[fields[0].lower()])
  # some intervening code here
  # ...
  s2 = sortedset(key = lambda x : x[fields[0].lower()])

Would I expect the key functions for s1 and s2 to evaluate as equal?

If the intervening code contains any function call at all, the value of fields may be modified, resulting in different key functions for s1 and s2. Since we clearly won't be doing control flow analysis to solve this problem, it's clear that we have to evaluate these two lambda functions as different, if we are trying to perform this evaluation before runtime. (Even if fields wasn't global, it might have been had another name bound to it, etc.) This would severely curtail the usefulness of this whole exercise, since few lambda functions would have no dependence on the environment.

EDIT 2:

I realized it's very important to compare the function objects as they exist in runtime. Without that, all the functions that depend on variables from outer scope cannot be compared; and most useful functions do have such dependencies. Considered in runtime, all functions with the same signature are comparable in a clean, logical way, regardless of what they depend on, whether they are impure, etc.

As a result, I need not just the bytecode but also the global state as of the time the function object was created (presumably __globals__). Then I have to match all variables from outer scope to the values from __globals__.

解决方案

Edited to check whether external state will affect the sorting function as well as if the two functions are equivalent.


I hacked up dis.dis and friends to output to a global file-like object. I then stripped out line numbers and normalized variable names (without touching constants) and compared the result.

You could clean this up so dis.dis and friends yielded out lines so you wouldn't have to trap their output. But this is a working proof-of-concept for using dis.dis for function comparison with minimal changes.

import types
from opcode import *
_have_code = (types.MethodType, types.FunctionType, types.CodeType,
              types.ClassType, type)

def dis(x):
    """Disassemble classes, methods, functions, or code.

    With no argument, disassemble the last traceback.

    """
    if isinstance(x, types.InstanceType):
        x = x.__class__
    if hasattr(x, 'im_func'):
        x = x.im_func
    if hasattr(x, 'func_code'):
        x = x.func_code
    if hasattr(x, '__dict__'):
        items = x.__dict__.items()
        items.sort()
        for name, x1 in items:
            if isinstance(x1, _have_code):
                print >> out,  "Disassembly of %s:" % name
                try:
                    dis(x1)
                except TypeError, msg:
                    print >> out,  "Sorry:", msg
                print >> out
    elif hasattr(x, 'co_code'):
        disassemble(x)
    elif isinstance(x, str):
        disassemble_string(x)
    else:
        raise TypeError, \
              "don't know how to disassemble %s objects" % \
              type(x).__name__

def disassemble(co, lasti=-1):
    """Disassemble a code object."""
    code = co.co_code
    labels = findlabels(code)
    linestarts = dict(findlinestarts(co))
    n = len(code)
    i = 0
    extended_arg = 0
    free = None
    while i < n:
        c = code[i]
        op = ord(c)
        if i in linestarts:
            if i > 0:
                print >> out
            print >> out,  "%3d" % linestarts[i],
        else:
            print >> out,  '   ',

        if i == lasti: print >> out,  '-->',
        else: print >> out,  '   ',
        if i in labels: print >> out,  '>>',
        else: print >> out,  '  ',
        print >> out,  repr(i).rjust(4),
        print >> out,  opname[op].ljust(20),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            extended_arg = 0
            i = i+2
            if op == EXTENDED_ARG:
                extended_arg = oparg*65536L
            print >> out,  repr(oparg).rjust(5),
            if op in hasconst:
                print >> out,  '(' + repr(co.co_consts[oparg]) + ')',
            elif op in hasname:
                print >> out,  '(' + co.co_names[oparg] + ')',
            elif op in hasjrel:
                print >> out,  '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                print >> out,  '(' + co.co_varnames[oparg] + ')',
            elif op in hascompare:
                print >> out,  '(' + cmp_op[oparg] + ')',
            elif op in hasfree:
                if free is None:
                    free = co.co_cellvars + co.co_freevars
                print >> out,  '(' + free[oparg] + ')',
        print >> out

def disassemble_string(code, lasti=-1, varnames=None, names=None,
                       constants=None):
    labels = findlabels(code)
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        if i == lasti: print >> out,  '-->',
        else: print >> out,  '   ',
        if i in labels: print >> out,  '>>',
        else: print >> out,  '  ',
        print >> out,  repr(i).rjust(4),
        print >> out,  opname[op].ljust(15),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            print >> out,  repr(oparg).rjust(5),
            if op in hasconst:
                if constants:
                    print >> out,  '(' + repr(constants[oparg]) + ')',
                else:
                    print >> out,  '(%d)'%oparg,
            elif op in hasname:
                if names is not None:
                    print >> out,  '(' + names[oparg] + ')',
                else:
                    print >> out,  '(%d)'%oparg,
            elif op in hasjrel:
                print >> out,  '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                if varnames:
                    print >> out,  '(' + varnames[oparg] + ')',
                else:
                    print >> out,  '(%d)' % oparg,
            elif op in hascompare:
                print >> out,  '(' + cmp_op[oparg] + ')',
        print >> out

def findlabels(code):
    """Detect all offsets in a byte code which are jump targets.

    Return the list of offsets.

    """
    labels = []
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            label = -1
            if op in hasjrel:
                label = i+oparg
            elif op in hasjabs:
                label = oparg
            if label >= 0:
                if label not in labels:
                    labels.append(label)
    return labels

def findlinestarts(code):
    """Find the offsets in a byte code which are start of lines in the source.

    Generate pairs (offset, lineno) as described in Python/compile.c.

    """
    byte_increments = [ord(c) for c in code.co_lnotab[0::2]]
    line_increments = [ord(c) for c in code.co_lnotab[1::2]]

    lastlineno = None
    lineno = code.co_firstlineno
    addr = 0
    for byte_incr, line_incr in zip(byte_increments, line_increments):
        if byte_incr:
            if lineno != lastlineno:
                yield (addr, lineno)
                lastlineno = lineno
            addr += byte_incr
        lineno += line_incr
    if lineno != lastlineno:
        yield (addr, lineno)

class FakeFile(object):
    def __init__(self):
        self.store = []
    def write(self, data):
        self.store.append(data)

a = lambda x : x
b = lambda x : x # True
c = lambda x : 2 * x
d = lambda y : 2 * y # True
e = lambda x : 2 * x
f = lambda x : x * 2 # True or False is fine, but must be stable
g = lambda x : 2 * x
h = lambda x : x + x # True or False is fine, but must be stable

funcs = a, b, c, d, e, f, g, h

outs = []
for func in funcs:
    out = FakeFile()
    dis(func)
    outs.append(out.store)

import ast

def outfilter(out):
    for i in out:
        if i.strip().isdigit():
            continue
        if '(' in i:
            try:
                ast.literal_eval(i)
            except ValueError:
                i = "(x)"
        yield i

processed_outs = [(out, 'LOAD_GLOBAL' in out or 'LOAD_DECREF' in out)
                            for out in (''.join(outfilter(out)) for out in outs)]

for (out1, polluted1), (out2, polluted2) in zip(processed_outs[::2], processed_outs[1::2]):
    print 'Bytecode Equivalent:', out1 == out2, '\nPolluted by state:', polluted1 or polluted2

The output is True, True, False, and False and is stable. The "Polluted" bool is true if the output will depend on external state -- either global state or a closure.

这篇关于开发启发式来测试简单的匿名Python函数的等价性的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆