(** Code to implement differentiation on arbitrary types.

    diff.ml: Library to compute exact derivatives of arbitrary
    mathematical functions.
    Copyright (C) 2006 Will M. Farr

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or (at
    your option) any later version.

    This program is distributed in the hope that it will be useful, but
    WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
    02110-1301 USA. *)

(** Input signature for [Make]: the scalar type to differentiate over,
    with its arithmetic and elementary functions. *)
module type In = sig
  type t

  (** The multiplicative identity. *)
  val one : t

  (** The additive identity. *)
  val zero : t

  (** Comparison returning a negative, zero or positive integer; the
      [diff] comparison operators produced by [Make] are built on it. *)
  val compare : t -> t -> int

  (** Arithmetic operations. *)
  val ( + ) : t -> t -> t
  val ( - ) : t -> t -> t
  val ( * ) : t -> t -> t
  val ( / ) : t -> t -> t

  (** Trigonometric functions. *)
  val sin : t -> t
  val cos : t -> t
  val tan : t -> t

  (** Inverse trigonometric functions. *)
  val asin : t -> t
  val acos : t -> t
  val atan : t -> t

  (** Powers, exponents and logs. *)
  val sqrt : t -> t
  val log : t -> t
  val exp : t -> t
  val ( ** ) : t -> t -> t

  (** [print_t out x] writes [x] to [out]. *)
  val print_t : out_channel -> t -> unit
end

(** Output signature for [Make]. *)
module type Out = sig
  (** Scalar type, taken from [In]. *)
  type t

  (** Opaque tags for variables being differentiated. *)
  type tag

  (** Output diff type.  A [diff] is either a constant [t] value [C(x)] or
      a tuple [D(tag, x, dx)] where [x] and [dx] are themselves [diff]s.
      [D(tag, x, dx)] represents a small increment [dx] in the value
      labeled by [tag] about the value [x]. *)
  type diff =
    | C of t
    | D of tag * diff * diff

  (** Additive and multiplicative identities in [diff]. *)
  val zero : diff
  val one : diff

  (** [lift f f'] takes [f : t -> t] to [diff -> diff] using [f'] to
      compute the derivative.  For example, if we had already defined
      [(/) : diff -> diff -> diff], and the input signature did not provide
      [log], we could define it using [let log = lift log ((/) one)]. *)
  val lift : (t -> t) -> (diff -> diff) -> (diff -> diff)

  (** [lower f] takes a function defined on [diff] to the equivalent one
      defined on [t].  The function is called [lower] because it is the
      inverse of [lift], which lifts [f : t -> t] to [diff -> diff].
      @raise Failure if [f (C x)] does not evaluate to [C y] for some [y]. *)
  val lower : (diff -> diff) -> (t -> t)

  (** [lower_multi f] lowers [f] from [diff array -> diff] to
      [t array -> t].
      @raise Failure if [f \[|C x0; ...; C xn|\]] does not evaluate to
      [C y] for some [y]. *)
  val lower_multi : (diff array -> diff) -> (t array -> t)

  (** [lower_multi_multi f] lowers [f] from [diff array -> diff array] to
      [t array -> t array].
      @raise Failure if [f \[|C x0; ...; C xn|\]] does not evaluate to
      [\[|C y0; ...; C ym|\]]. *)
  val lower_multi_multi : (diff array -> diff array) -> (t array -> t array)

  (** [lower_jacobian j] lowers the jacobian [j] from
      [diff array -> diff array array] to [t array -> t array array].
      @raise Failure if some entry of the jacobian is not a constant. *)
  val lower_jacobian :
    (diff array -> diff array array) -> (t array -> t array array)

  (** [d f] returns the function which computes the derivative of [f].
      Applications of [d] may be nested: the argument may itself depend on
      the variable of an enclosing [d], [partial] or [jacobian]. *)
  val d : (diff -> diff) -> (diff -> diff)

  (** [partial i f] returns the function which computes the derivative of
      [f] with respect to its [i]th argument (counting from 0). *)
  val partial : int -> (diff array -> diff) -> (diff array -> diff)

  (** [jacobian f] returns the jacobian matrix with elements
      [m.(i).(j) = d f.(i)/d x.(j)].  The matrix need not be square: it
      has one row per output of [f] and one column per argument. *)
  val jacobian : (diff array -> diff array) -> (diff array -> diff array array)

  (** [compare d1 d2] compares only the values stored in the derivative,
      that is either the [C x] or [D(_, x, _)].  [(<)], [(>)], ... are
      defined in terms of this [compare], which is defined in terms of
      [compare] from [In]. *)
  val compare : diff -> diff -> int

  (** Comparison functions. *)
  val ( < ) : diff -> diff -> bool
  val ( <= ) : diff -> diff -> bool
  val ( > ) : diff -> diff -> bool
  val ( >= ) : diff -> diff -> bool
  val ( = ) : diff -> diff -> bool
  val ( <> ) : diff -> diff -> bool

  (** Algebra. *)
  val ( + ) : diff -> diff -> diff
  val ( - ) : diff -> diff -> diff
  val ( * ) : diff -> diff -> diff
  val ( / ) : diff -> diff -> diff

  (** Trig. *)
  val cos : diff -> diff
  val sin : diff -> diff
  val tan : diff -> diff

  (** Inverse trig. *)
  val acos : diff -> diff
  val asin : diff -> diff
  val atan : diff -> diff

  (** Powers, exponents and logs. *)
  val sqrt : diff -> diff
  val log : diff -> diff
  val exp : diff -> diff
  val ( ** ) : diff -> diff -> diff

  (** [print_diff out d] writes the tree structure of [d] to [out]. *)
  val print_diff : out_channel -> diff -> unit
end

module Make (I : In) : Out with type t = I.t = struct
  type t = I.t
  type tag = int

  (* Terms are either constant or of the form (x + dx), with x represented
     by tag.  We are careful to maintain the invariant that all the tags in
     x and dx are larger than the tag of (x + dx).  That is, you can think
     of a D(tag, x, dx) as a tree

           tag
          /   \
         x     dx

     which satisfies the heap property tag < tagx and tag < tagdx.
     Every constructor of D nodes below, including the seeding done by
     [d] and [jacobian], must preserve this property. *)
  type diff =
    | C of t
    | D of tag * diff * diff

  let rec print_diff out = function
    | C x ->
        Printf.fprintf out "C(";
        I.print_t out x;
        Printf.fprintf out ")"
    | D (tag, x, dx) ->
        Printf.fprintf out "D(%d," tag;
        print_diff out x;
        Printf.fprintf out ", ";
        print_diff out dx;
        Printf.fprintf out ")"

  (* Additive and multiplicative identities in derivatives. *)
  let zero = C I.zero
  let one = C I.one

  (* Unique tags.  Each call returns a strictly larger integer, so a fresh
     tag is larger than every tag already present in existing values. *)
  let new_tag =
    let count = ref 0 in
    fun () ->
      count := !count + 1;
      !count

  (* Have to define the arithmetic operators first because they are used in
     [lift] and friends.  To maintain the heap property of the tags, we
     select the smallest of tagx and tagy when we are operating on two
     D(_, _, _) objects.  We know that we can directly construct
     D(smallest, _, _), where the children can be any of the
     sub-derivatives.  But we can only use the sub-derivatives of the larger
     one in any direct construction D(larger, _, _); there is no guarantee
     that the tags in the sub-derivatives of D(smaller, _, _) are in any
     relation to larger.  This is the reason for the asymmetric cases below.

     Note: the guards below compare tags with the integer [=] and [<]; the
     diff comparison operators are only defined at the very end. *)
  let rec ( + ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( + ) x y)
    | C _, D (tag, y, dy) -> D (tag, d1 + y, dy)
    | D (tag, x, dx), C _ -> D (tag, x + d2, dx)
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
        D (tagx, x + y, dx + dy)
    | D (tagx, x, dx), D (tagy, _, _) when tagx < tagy -> D (tagx, x + d2, dx)
    | D _, D (tagy, y, dy) -> D (tagy, d1 + y, dy)

  let rec ( - ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( - ) x y)
    | C _, D (tag, y, dy) -> D (tag, d1 - y, zero - dy)
    | D (tag, x, dx), C _ -> D (tag, x - d2, dx)
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
        D (tagx, x - y, dx - dy)
    | D (tagx, x, dx), D (tagy, _, _) when tagx < tagy -> D (tagx, x - d2, dx)
    | D _, D (tagy, y, dy) -> D (tagy, d1 - y, zero - dy)

  let rec ( * ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( * ) x y)
    | C _, D (tag, y, dy) -> D (tag, d1 * y, d1 * dy)
    | D (tag, x, dx), C _ -> D (tag, x * d2, dx * d2)
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
        D (tagx, x * y, (x * dy) + (dx * y))
    | D (tagx, x, dx), D (tagy, _, _) when tagx < tagy ->
        D (tagx, x * d2, dx * d2)
    | D _, D (tagy, y, dy) -> D (tagy, d1 * y, d1 * dy)

  let rec ( / ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( / ) x y)
    | D (tag, x, dx), C _ -> D (tag, x / d2, dx / d2)
    | C _, D (tag, y, dy) -> D (tag, d1 / y, zero - (d1 * dy / (y * y)))
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
        D (tagx, x / y, (dx / y) - (x * dy / (y * y)))
    | D (tagx, x, dx), D (tagy, y, dy) when tagx < tagy ->
        (* 1/d2 = 1/y - (dy/y^2) eps_tagy; built once and shared. *)
        let recip = D (tagy, one / y, zero - (dy / (y * y))) in
        D (tagx, x * recip, dx * recip)
    | D _, D (tagy, y, dy) -> D (tagy, d1 / y, zero - (d1 * dy / (y * y)))

  (* Now that the algebra of derivatives is worked out, we can define
     [lift] and [lower]. *)
  let lift f f' =
    let rec lf = function
      | C x -> C (f x)
      | D (tag, x, dx) -> D (tag, lf x, f' x * dx)
    in
    lf

  let lower f x =
    match f (C x) with
    | C y -> y
    | D _ -> failwith "lower expects numerical output"

  (* [extract_deriv tag d] is the coefficient of the increment labelled
     [tag] in [d] (zero when [tag] does not occur). *)
  let rec extract_deriv tag = function
    | C _ -> zero
    | D (tagx, _, dx) when tagx = tag -> dx
    | D (tagx, x, dx) -> D (tagx, extract_deriv tag x, extract_deriv tag dx)

  (* [drop_deriv tag d] is [d] with the increment labelled [tag] set to
     zero. *)
  let rec drop_deriv tag = function
    | C _ as c -> c
    | D (tagx, x, _) when tagx = tag -> x
    | D (tagx, x, dx) -> D (tagx, drop_deriv tag x, drop_deriv tag dx)

  (* [perturb tag x] builds x + eps_tag, i.e. seeds direction [tag] with a
     unit increment on top of [x].  [tag] must be larger than every tag
     inside [x], which holds because callers take it straight from
     [new_tag].  Writing D(tag, x, one) directly would break the heap
     invariant whenever [x] already carries older tags (nested
     differentiation), so the new node is pushed down to the constant at
     the bottom of the value chain of [x] instead. *)
  let rec perturb tag = function
    | C _ as c -> D (tag, c, one)
    | D (tagx, x, dx) -> D (tagx, perturb tag x, dx)

  let d f x =
    let tag = new_tag () in
    extract_deriv tag (f (perturb tag x))

  let rec compare d1 d2 =
    match d1, d2 with
    | C x, C y -> I.compare x y
    | C _, D (_, y, _) -> compare d1 y
    | D (_, x, _), C _ -> compare x d2
    | D (_, x, _), D (_, y, _) -> compare x y

  let rec cos = function
    | C x -> C (I.cos x)
    | D (tag, x, dx) -> D (tag, cos x, zero - (dx * sin x))

  and sin = function
    | C x -> C (I.sin x)
    | D (tag, x, dx) -> D (tag, sin x, dx * cos x)

  let rec tan = function
    | C x -> C (I.tan x)
    | D (tag, x, dx) ->
        let tx = tan x in
        D (tag, tx, (one + (tx * tx)) * dx)

  let rec sqrt = function
    | C x -> C (I.sqrt x)
    | D (tag, x, dx) ->
        let sx = sqrt x in
        D (tag, sx, dx / (sx + sx))

  let rec acos = function
    | C x -> C (I.acos x)
    | D (tag, x, dx) -> D (tag, acos x, zero - (dx / sqrt (one - (x * x))))

  let rec asin = function
    | C x -> C (I.asin x)
    | D (tag, x, dx) -> D (tag, asin x, dx / sqrt (one - (x * x)))

  let rec atan = function
    | C x -> C (I.atan x)
    | D (tag, x, dx) -> D (tag, atan x, dx / (one + (x * x)))

  let rec log = function
    | C x -> C (I.log x)
    | D (tag, x, dx) -> D (tag, log x, dx / x)

  let rec exp = function
    | C x -> C (I.exp x)
    | D (tag, x, dx) ->
        let ex = exp x in
        D (tag, ex, dx * ex)

  let rec ( ** ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( ** ) x y)
    | C _, D (tag, y, dy) -> D (tag, d1 ** y, (d1 ** y) * log d1 * dy)
    | D (tag, x, dx), C _ -> D (tag, x ** d2, (x ** (d2 - one)) * d2 * dx)
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
        D (tagx, x ** y, (x ** (y - one)) * ((dx * y) + (x * log x * dy)))
    | D (tagx, x, dx), D (tagy, y, dy) when tagx < tagy ->
        (* eps_y is the increment of the exponent, labelled tagy. *)
        let eps_y = D (tagy, zero, dy) in
        D
          ( tagx,
            (x ** y) * (one + (log x * eps_y)),
            dx * (x ** (y - one)) * (y + ((one + (y * log x)) * eps_y)) )
    | D (tagx, x, dx), D (tagy, y, dy) ->
        (* eps_x is the increment of the base, labelled tagx. *)
        let eps_x = D (tagx, zero, dx) in
        D
          ( tagy,
            (x ** y) + (y * (x ** (y - one)) * eps_x),
            dy
            * (((x ** y) * log x)
              + ((x ** (y - one)) * (one + (y * log x)) * eps_x)) )

  (* [replace arr i x] is a copy of [arr] with element [i] set to [x]. *)
  let replace arr i x = Array.mapi (fun j y -> if j <> i then y else x) arr

  let partial i f args =
    let one_d_f x = f (replace args i x) in
    d one_d_f args.(i)

  let jacobian f args =
    let n = Array.length args in
    let tags = Array.init n (fun _ -> new_tag ()) in
    let dargs = Array.mapi (fun i arg -> perturb tags.(i) arg) args in
    let result = f dargs in
    (* Coefficient of direction [tj] in [fi]: extract it, drop the others. *)
    let entry fi tj =
      Array.fold_left
        (fun res tag ->
          if tag = tj then extract_deriv tag res else drop_deriv tag res)
        fi tags
    in
    (* One row per output of [f], one column per argument. *)
    Array.map (fun fi -> Array.map (entry fi) tags) result

  let c_ify arr = Array.map (fun x -> C x) arr

  let de_c_ify arr =
    Array.map
      (function
        | C x -> x
        | D _ -> failwith "cannot lower [|D(...); ...; D(...)|]")
      arr

  let lower_multi f args =
    match f (c_ify args) with
    | C x -> x
    | D _ -> failwith "cannot lower D(...)"

  let lower_multi_multi f args = de_c_ify (f (c_ify args))

  let lower_jacobian j args = Array.map de_c_ify (j (c_ify args))

  (* Define these all at once (and last) so as not to spoil the comparison
     operator namespace: every tag guard above uses the integer versions. *)
  let ( < ) d1 d2 = compare d1 d2 < 0
  and ( <= ) d1 d2 = not (compare d1 d2 > 0)
  and ( > ) d1 d2 = compare d1 d2 > 0
  and ( >= ) d1 d2 = not (compare d1 d2 < 0)
  and ( = ) d1 d2 = compare d1 d2 = 0
  and ( <> ) d1 d2 = not (compare d1 d2 = 0)
end

(** Derivatives of [float] functions.  The scalar operations are bound
    explicitly rather than through [include Pervasives], which no longer
    exists in OCaml 5. *)
module DFloat = Make (struct
  type t = float

  let one = 1.0
  let zero = 0.0
  let compare (x : float) (y : float) = compare x y
  let ( + ) = ( +. )
  let ( - ) = ( -. )
  let ( * ) = ( *. )
  let ( / ) = ( /. )
  let sin = sin
  let cos = cos
  let tan = tan
  let asin = asin
  let acos = acos
  let atan = atan
  let sqrt = sqrt
  let log = log
  let exp = exp
  let ( ** ) = ( ** )
  let print_t out = Printf.fprintf out "%g"
end)