将 NumPy 中定义的函数转换为 SymPy [英] Convert a function defined in NumPy to SymPy

查看:39
本文介绍了将 NumPy 中定义的函数转换为 SymPy的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在 numpy 中定义了一个函数,我想将其转换为 sympy,因此我可以将其应用于符号 sympy 变量.尝试将 numpy 函数直接应用于 sympy 变量失败:

I have a function defined in numpy which I would like to convert to sympy, so I can apply it to symbolic sympy variables. Trying to directly apply the numpy function to a sympy variable fails:

import numpy as np
import sympy as sp

def np_fun(a):
    return np.array([np.sin(a), np.cos(a)])

x = sp.symbols('x')
sp_fun = np_fun(x)

我收到错误

AttributeError: 'Symbol' object has no attribute 'sin'

我的下一个想法是将 numpy 函数转换为 sympy,但我找不到办法做到这一点.我知道我可以通过将函数定义为一个 sympy 表达式来使这段代码工作:

My next thought was to convert the numpy function to sympy, but I couldn't find a way to do that. I know I could make this code work by just defining the function as a sympy expression:

sp_fun = sp.Array([sp.sin(x), sp.cos(x)])

但我使用正弦/余弦函数作为一个简单的例子.我实际使用的函数已经在numpy中定义过了,而且要复杂得多,重写起来会很繁琐.

But I'm using the sine/cosine function as a simple example. The actual function I'm using has already been defined in numpy, and is much more complicated, so it would be very tedious to rewrite it.

推荐答案

原则上,您可以直接修改函数的 ast(抽象语法树"),但在实践中可能会变得很麻烦.无论如何,以下是为您的简单示例执行此操作的方法:

In principle, you could directly modify the ast ("abstract syntax tree") of the function, though in practice it might get quite hairy. Anyway, here is how to do it for your simple example:

这从源创建一个 ast 并从 NodeTransformer 类派生以就地修改 ast.节点转换器有一个通用的访问方法,它遍历一个节点及其子树,在派生类中委托给节点特定的访问者.在这里,我们将所有名称 np 更改为 sp,然后将这些属性更改为以前的 np 现在 sp 拼写不同.您必须将所有这些差异添加到 translate 字典中.

This creates from the source an ast and derives from the NodeTransformer class to modify the ast in-place. The node transformer has a generic visit method that traverses a node and its subtree delegating to node specific visitors in derived classes. Here we change all names np to sp and afterwards change those attributes to former np now sp that spell differently. You'd have to add all such differences to the translate dict.

最后,我们从 ast 编译回代码对象并执行它以使修改后的函数可用.

Finally, we compile back from the ast to a code object and execute it to make the modified function available.

import ast, inspect
import numpy as np
import sympy as sp

def f(a):
    return np.array([np.sin(a), np.cos(a)])

z = ast.parse(inspect.getsource(f))

translate = {'array': 'Array'}

class np_to_sp(ast.NodeTransformer):
    def visit_Name(self, node):
        if node.id=='np':
            node = ast.copy_location(ast.Name(id='sp', ctx=node.ctx), node)
        return node
    def visit_Attribute(self, node):
        self.generic_visit(node)
        if node.value.id=='sp' and node.attr in translate:
            fields = {k: getattr(node, k) for k in node._fields}
            fields['attr'] = translate[node.attr]
            node = ast.copy_location(ast.Attribute(**fields), node)
        return node

np_to_sp().visit(z)

exec(compile(z, '', 'exec'))

x = sp.Symbol('x')
print(f(x))

输出:

[sin(x), cos(x)]

UPDATE简单增强:修改函数调用的函数:

import ast, inspect
import numpy as np
import sympy as sp

def f(a):
    return np.array([np.sin(a), np.cos(a)])

def f2(a):
    return np.array([1, np.sin(a)])

def f3(a):
    return f(a) + f2(a)

translate = {'array': 'Array'}

class np_to_sp(ast.NodeTransformer):
    def visit_Name(self, node):
        if node.id=='np':
            node = ast.copy_location(ast.Name(id='sp', ctx=node.ctx), node)
        return node
    def visit_Attribute(self, node):
        self.generic_visit(node)
        if node.value.id=='sp' and node.attr in translate:
            fields = {k: getattr(node, k) for k in node._fields}
            fields['attr'] = translate[node.attr]
            node = ast.copy_location(ast.Attribute(**fields), node)
        return node

from types import FunctionType

for fn in f3.__code__.co_names:
    fo = globals()[fn]
    if not isinstance(fo, FunctionType):
        continue
    z = ast.parse(inspect.getsource(fo))
    np_to_sp().visit(z)
    exec(compile(z, '', 'exec'))

x = sp.Symbol('x')
print(f3(x))

打印:

[sin(x) + 1, sin(x) + cos(x)]

这篇关于将 NumPy 中定义的函数转换为 SymPy的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆