如何对复杂数据类型进行自动微分? [英] How to do automatic differentiation on complex datatypes?

查看:23
本文介绍了如何对复杂数据类型进行自动微分?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

给出一个基于向量的非常简单的矩阵定义:

Given a very simple Matrix definition based on Vector:

import Numeric.AD
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }

scale' f = Mat . V.map (*f) . unMat
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)
mul' a b = Mat $ V.zipWith (*) (unMat a) (unMat b)
pow' a e = Mat $ V.map (^e) (unMat a)

sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat

(出于演示目的......我正在使用 hmatrix,但认为问题出在那里)

(for demonstration purposes ... I am using hmatrix but thought the problem was there somehow)

还有一个误差函数(eq3):

And an error function (eq3):

eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs

eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2::Int))
  where errImg = img `sub'` (eq1' as φs)

为什么编译器无法推导出正确的类型?

Why the compiler not able to deduce the right types in this?

diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m φs as0 = gradientDescent go as0
  where go xs = eq3' m xs φs

确切的错误信息是这样的:

The exact error message is this:

src/Stuff.hs:59:37:
    Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
    from the context (Fractional a, Ord a)
      bound by the type signature for
                 diffTest :: (Fractional a, Ord a) =>
                             Mat a -> [Mat a] -> [a] -> [[a]]
      at src/Stuff.hs:58:13-69
    or from (reflection-1.5.1.2:Data.Reflection.Reifies
               s Numeric.AD.Internal.Reverse.Tape)
      bound by a type expected by the context:
                 reflection-1.5.1.2:Data.Reflection.Reifies
                   s Numeric.AD.Internal.Reverse.Tape =>
                 [Numeric.AD.Internal.Reverse.Reverse s a]
                 -> Numeric.AD.Internal.Reverse.Reverse s a
      at src/Stuff.hs:59:21-42
      ‘a’ is a rigid type variable bound by
          the type signature for
            diffTest :: (Fractional a, Ord a) =>
                        Mat a -> [Mat a] -> [a] -> [[a]]
          at src//Stuff.hs:58:13
    Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                   -> Numeric.AD.Internal.Reverse.Reverse s a
      Actual type: [a] -> a
    Relevant bindings include
      go :: [a] -> a (bound at src/Stuff.hs:60:9)
      as0 :: [a] (bound at src/Stuff.hs:59:15)
      φs :: [Mat a] (bound at src/Stuff.hs:59:12)
      m :: Mat a (bound at src/Stuff.hs:59:10)
      diffTest :: Mat a -> [Mat a] -> [a] -> [[a]]
        (bound at src/Stuff.hs:59:1)
    In the first argument of ‘gradientDescent’, namely ‘go’
    In the expression: gradientDescent go as0

推荐答案

gradientDescent 函数来自 ad 有类型

gradientDescent :: (Traversable f, Fractional a, Ord a) =>
                   (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) ->
                   f a -> [f a]

它的第一个参数需要一个 f r -> 类型的函数.r 其中 rforall s.(反向 s a).go 的类型为 [a] ->a 其中 adiffTest 签名中的类型绑定.这些 a 是一样的,但是 Reverse s aa 不一样.

Its first argument requires a function of the type f r -> r where r is forall s. (Reverse s a). go has the type [a] -> a where a is the type bound in the signature of diffTest. These as are the same, but Reverse s a isn't the same as a.

反转 类型具有许多类型类的实例,可以允许我们将 a 转换为 Reverse sa 或返回.最明显的是分数a =>分数 (Reverse s a) 允许我们使用 realToFracas 转换为 Reverse s as.

The Reverse type has instances for a number of type classes that could allow us to convert an a into a Reverse s a or back. The most obvious is Fractional a => Fractional (Reverse s a) which would allow us to convert as into Reverse s as with realToFrac.

为此,我们需要能够映射一个函数 a ->bMat a 上得到一个 Mat b.最简单的方法是为 Mat 派生一个 Functor 实例.

To do so, we'll need to be able to map a function a -> b over a Mat a to obtain a Mat b. The easiest way to do this will be to derive a Functor instance for Mat.

{-# LANGUAGE DeriveFunctor #-}

newtype Mat a = Mat { unMat :: V.Vector a }
    deriving Functor

我们可以将 mfs 转换成任何 分数 a' =>;Mat a'fmap realToFrac.

We can convert the m and fs into any Fractional a' => Mat a' with fmap realToFrac.

diffTest m fs as0 = gradientDescent go as0
  where go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)

但是有更好的方法隐藏在广告包中.Reverse sa 在所有 s 上通用,但 aa 中的绑定相同diffTest 的类型签名.我们真的只需要一个函数 a ->(forall s. Reverse s a).这个函数是auto 来自 Mode 类,其中 Reverse sa 有一个实例.auto 有一个稍微奇怪的类型 Mode t =>标量 t ->t类型标量(反向 s a)= a.专用于 Reverse auto 有类型

But there's a better way hiding in the ad package. The Reverse s a is universally qualified over all s but the a is the same a as the one bound in the type signature for diffTest. We really only need a function a -> (forall s. Reverse s a). This function is auto from the Mode class, for which Reverse s a has an instance. auto has the slightly wierd type Mode t => Scalar t -> t but type Scalar (Reverse s a) = a. Specialized for Reverse auto has the type

auto :: (Reifies s Tape, Num a) => a -> Reverse s a

这允许我们将 Mat as 转换为 Mat (Reverse sa)s,而无需在 Rational 之间进行转换.

This allows us to convert our Mat as into Mat (Reverse s a)s without messing around with conversions to and from Rational.

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
  where
    go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
    go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)

这篇关于如何对复杂数据类型进行自动微分?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆