为 sympy.Derivative 创建自定义打印 [英] Creating custom printing for sympy.Derivative

查看:27
本文介绍了为 sympy.Derivative 创建自定义打印的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有一个函数 f(x,y),我希望 f w.r.t 到 x 的偏导数显示为\partial_{x}^{n} f(x,y) 所以我创建了以下类

Say I have a function f(x,y), I want partial derivative of f w.r.t to x appear as \partial_{x}^{n} f(x,y) so I created the following class

class D(sp.Derivative):
    def _latex(self,printer=None):
        func = printer.doprint(self.args[0])
        b = self.args[1]
        if b[1] == 1 :
            return r"\partial_{%s}%s"%(printer.doprint(b[0]),func)
        else :
            return r"\partial_{%s}^{%s}%s"%(printer.doprint(b[0]),printer.doprint(b[1]),func)

工作正常,但是当我使用 doit() 方法评估导数时会返回默认行为.说我有

which works fine, but goes back to default behavior when I evaluate the derivative by using doit() method. Say I have

x,y = sp.symbols('x,y')
f = sp.Function('f')(x,y)

然后 sp.print_latex(D(f,x)) 给出 \partial_{x}f{\left(x,y \right)} 这是正确的,但 sp.print_latex(D(x*f,x).doit()) 产生 x \frac{\partial}{\partial x} f{\left(x,y\right)} + f{\left(x,y \right)},这是旧的行为.我该如何解决这个问题?

Then sp.print_latex(D(f,x)) gives \partial_{x}f{\left(x,y \right)} which is correct, but sp.print_latex(D(x*f,x).doit()) yields x \frac{\partial}{\partial x} f{\left(x,y \right)} + f{\left(x,y \right)}, which is the old behavior. How can I fix this issue?

推荐答案

问题是你没有从父类覆盖 doit 并且它返回普通的 Derivative对象而不是您的子类.我建议创建一个新的打印机类,而不是创建一个新的 Derivative 类:

The problem is that you haven't overridden doit from the parent class and it returns plain Derivative objects rather than your subclass. Rather than creating a new Derivative class I suggest to create a new printer class:

from sympy import *

from sympy.printing.latex import LatexPrinter

class MyLatexPrinter(LatexPrinter):
    def _print_Derivative(self, expr):
        differand, *(wrt_counts) = expr.args
        if len(wrt_counts) > 1 or wrt_counts[0][1] != 1:
            raise NotImplementedError('More code needed...')
        ((wrt, count),) = wrt_counts
        return '\partial_{%s} %s)' % (self._print(wrt), self._print(differand))

x, y = symbols('x, y')
f = Function('f')
expr = (x*f(x, y)).diff(x)

printer = MyLatexPrinter()

print(printer.doprint(expr))

这给 x \partial_{x} f{\left(x,y \right)}) + f{\left(x,y \right)}

您可以使用 init_printing(latex_printer=printer.doprint) 使其成为默认输出.

You can use init_printing(latex_printer=printer.doprint) to make this the default output.

这篇关于为 sympy.Derivative 创建自定义打印的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆