Library Interval.Float.Primitive_ops

From Coq Require Import ZArith Reals.
Require Import Int63Compat.
From Coq Require Import Floats Psatz.
From Flocq Require Import Zaux Raux BinarySingleNaN PrimFloat Sterbenz Mult_error.

(* Fallback definitions of [ldexp] and [frexp].  They are placeholders, not
   working implementations: this [ldexp] returns [f] unchanged (its exponent
   argument is ignored), and this [frexp] always pairs [f] with exponent [Z0].
   NOTE(review): presumably these exist only so that the [Z.ldexp]/[Z.frexp]
   aliases defined next resolve on every supported Coq version.  If this
   fallback were ever the binding actually selected, [Z.ldexp] in
   [fromZ_UP]/[fromZ_DN] would ignore its exponent and return wrong bounds
   for |x| >= 2^53 -- confirm against each supported Coq version. *)
Module Import Compat.
Definition ldexp f (_ : Z) : float := f.
Definition frexp (f : float) := (f, Z0).
End Compat.
Import FloatOps.
(* Make [ldexp] and [frexp] available under the [Z.] prefix ([Z.ldexp] is
   used by [fromZ_UP] and [fromZ_DN]).  Each notation stands for whichever
   unqualified [ldexp]/[frexp] is in scope here.  Presumably that is the
   [FloatOps] one (imported just above) when it exists, and the [Compat]
   fallback otherwise.
   NOTE(review): this relies on Coq name-resolution and shadowing order;
   confirm on each supported Coq version which definition is picked up. *)
Module Import Z.
Notation ldexp := ldexp.
Notation frexp := frexp.
End Z.
Import Floats.
Import Zaux BinarySingleNaN.

Require Import Missing.Stdlib Missing.Flocq.
Require Import Xreal.
Require Import Basic.
Require Import Sig.
Require Generic_proof.

(** * Primitive-float implementation of the [FloatOps] interface

    [PrimitiveFloat] implements the [FloatOps] signature (from [Sig]) with
    Coq primitive floating-point numbers ([PrimFloat.float]).

    Directed ([_UP]/[_DN]) operations work by moving the result of the
    corresponding primitive operation one float up or down with
    [next_up]/[next_down].  Their [precision] arguments are ignored.

    NOTE(review): this text looks like a coqdoc rendering rather than the
    original source.  Lemma proofs are omitted.  Several notation symbols are
    missing from many lines: match-arm arrows, implication arrows, universal
    quantifiers, order relations and logical connectives.  As shown, it does
    not parse; re-derive from the original file before editing code. *)
Module PrimitiveFloat <: FloatOps.

(** Radix and format flag required by [FloatOps]. *)
Definition radix := radix2.
Definition sensible_format := true.

(** The carrier type: primitive floats. *)
Definition type := PrimFloat.float.

(** Convert a primitive float to a [Basic.float radix2] via [Prim2SF]:
    - zeros map to [Fzero];
    - finite values map to [Basic.Float s m e];
    - both infinities and NaN map to [Fnan]. *)
Definition toF x : float radix2 :=
  match Prim2SF x with
  | S754_zero _Fzero
  | S754_infinity _ | S754_nanBasic.Fnan
  | S754_finite s m eBasic.Float s m e
  end.

(** Precisions and scale factors are plain integers.
    - [prec] extracts the positive precision, defaulting to [xH] (i.e. 1)
      for non-positive values.
    - [PtoP], [ZtoS] and [StoZ] are trivial conversions.
    - [incr_prec] adds a positive increment. *)
Definition precision := Z.
Definition sfactor := Z. Definition prec p := match p with Zpos qq | _xH end.
Definition PtoP p := Zpos p.
Definition ZtoS (x : Z) := x.
Definition StoZ (x : Z) := x.
Definition incr_prec p i := Zplus p (Zpos i).

(** Re-export the [zero] and [nan] constants already in scope (presumably
    the primitive-float ones). *)
Definition zero := zero.
Definition nan := nan.

(** Integer-to-float conversion.
    - If |x| < 2^53 (= 9007199254740992), the value is converted through
      [Int63.of_pos] and [of_int63].
    - Otherwise the result is [nan]. *)
Definition fromZ x :=
  match x with
  | Z0zero
  | Zpos x
    match (x ?= 9007199254740992)%positive with
    | Ltof_int63 (Int63.of_pos x)
    | _nan
    end
  | Zneg x
    match (x ?= 9007199254740992)%positive with
    | Lt ⇒ (-(of_int63 (Int63.of_pos x)))%float
    | _nan
    end
  end.

(** Upper bound on an integer as a float ([p] is unused).
    - |x| < 2^53: converted as in [fromZ].
    - Otherwise, let d = log2 |x| and e = d - 52.  Then m = |x| >> e
      satisfies m * 2^e <= |x| < (m + 1) * 2^e and m < 2^53 (see
      [shiftr_pos]).
      - For x > 0 the result is (m + 1) * 2^e.
      - For x < 0 it is the negation of m * 2^e, passed through [next_up].
        Presumably this guards against [Z.ldexp] overflowing to
        [neg_infinity]; in any case the result always satisfies [valid_ub]
        (see [valid_ub_next_up]). *)
Definition fromZ_UP (p : precision) x :=
  match x with
  | Z0zero
  | Zpos x
    match (x ?= 9007199254740992)%positive with
    | Ltof_int63 (Int63.of_pos x)
    | _
      let x := Zpos x in
      let d := Z.log2 x in
      let e := (d - 52)%Z in
      let m := Z.shiftr x e in
      Z.ldexp (of_int63 (of_Z m + 1)) e
    end
  | Zneg x
    match (x ?= 9007199254740992)%positive with
    | Lt ⇒ (-(of_int63 (Int63.of_pos x)))%float
    | _
      let x := Zpos x in
      let d := Z.log2 x in
      let e := (d - 52)%Z in
      let m := Z.shiftr x e in
      next_up (Z.ldexp (-(of_int63 (of_Z m))) e)
    end
  end.

(** Lower bound on an integer as a float ([p] is unused).  This is the
    mirror image of [fromZ_UP].
    - |x| < 2^53: converted as in [fromZ].
    - Otherwise, with d, e and m defined as in [fromZ_UP]:
      - for x > 0 the result is [next_down] of m * 2^e, which always
        satisfies [valid_lb] (see [valid_lb_next_down]);
      - for x < 0 it is the negation of (m + 1) * 2^e. *)
Definition fromZ_DN (p : precision) x :=
  match x with
  | Z0zero
  | Zpos x
    match (x ?= 9007199254740992)%positive with
    | Ltof_int63 (Int63.of_pos x)
    | _
      let x := Zpos x in
      let d := Z.log2 x in
      let e := (d - 52)%Z in
      let m := Z.shiftr x e in
      next_down (Z.ldexp (of_int63 (of_Z m)) e)
    end
  | Zneg x
    match (x ?= 9007199254740992)%positive with
    | Lt ⇒ (-(of_int63 (Int63.of_pos x)))%float
    | _
      let x := Zpos x in
      let d := Z.log2 x in
      let e := (d - 52)%Z in
      let m := Z.shiftr x e in
      Z.ldexp (-(of_int63 (Int63.of_Z m + 1))) e
    end
  end.

(** Conversion from a [Basic.float].
    - [Fnan] and [Fzero] map to [nan] and [zero].
    - [Float s m e] is handled only when -1074 <= e <= 971 and m has at most
      53 bits, so that m * 2^e is representable.  The mantissa goes through
      [of_int63]; it is then scaled with [ldshiftexp], whose exponent is
      biased by [FloatOps.shift]; finally it is negated when [s] is set.
    - Any other [Float s m e] yields [nan]. *)
Definition fromF (f : float radix) :=
  match f with
  | Basic.Fnannan
  | Basic.Fzerozero
  | Basic.Float s m e
    if ((e <=? 971)%Z && (-1074 <=? e)%Z
        && (Pos.size m <=? 53)%positive)%bool then
      let m := of_int63 (Int63.of_pos m) in
      let e := Int63.of_Z (e + FloatOps.shift) in
      let f := ldshiftexp m e in
      if s then (- f)%float else f
    else nan
  end.

(** Map the primitive classification to the [Sig] one:
    - [NaN] to [Sig.Fnan];
    - [PInf] to [Fpinfty];
    - [NInf] to [Fminfty];
    - every other class to [Freal]. *)
Definition classify x :=
  match classify x with
  | NaNSig.Fnan
  | PInfFpinfty
  | NInfFminfty
  | _Freal
  end.

(** [real x] is false exactly for the infinities and NaN. *)
Definition real x :=
  match PrimFloat.classify x with
  | PInf | NInf | NaNfalse
  | _true
  end.

(** [is_nan x] holds exactly when [x] is classified as [NaN]. *)
Definition is_nan x :=
  match PrimFloat.classify x with
  | NaNtrue
  | _false
  end.

(** An exponent bound for [x]: the exponent returned by [frshiftexp],
    un-biased by [FloatOps.shift].  [mag_correct] states that
    |toR x| < 2^(mag x). *)
Definition mag x :=
  let (_, e) := PrimFloat.frshiftexp x in
  (Int63.to_Z e - FloatOps.shift)%Z.

(** Any float except [neg_infinity] is a valid upper bound. *)
Definition valid_ub x := negb (PrimFloat.eqb x neg_infinity).

(** Any float except [infinity] is a valid lower bound. *)
Definition valid_lb x := negb (PrimFloat.eqb x infinity).

(** Translate a primitive comparison result into the [Xcomparison] type;
    not-comparable becomes [Xund]. *)
Definition Xcomparison_of_float_comparison c :=
  match c with
  | FEqXeq
  | FLtXlt
  | FGtXgt
  | FNotComparableXund
  end.

(** Three-way comparison based on the primitive [compare]. *)
Definition cmp x y := Xcomparison_of_float_comparison (compare x y).

(** Minimum of two floats.
    - If the arguments are not comparable, the result is [nan].
    - If they compare equal, the first argument is returned. *)
Definition min x y :=
  match (x ?= y)%float with
  | FEq | FLtx
  | FGty
  | FNotComparablenan
  end.

(** Maximum of two floats.
    - If the arguments are not comparable, the result is [nan].
    - If they compare equal, the first argument is returned. *)
Definition max x y :=
  match (x ?= y)%float with
  | FEq | FGtx
  | FLty
  | FNotComparablenan
  end.

(** Negation and absolute value, delegated to the primitive operations
    ([abs] on the right-hand side refers to the definition already in
    scope). *)
Definition neg x := (- x)%float.

Definition abs x := abs x.

(** [scale x e] multiplies [x] by 2^e using [ldshiftexp], whose exponent
    argument is biased by [FloatOps.shift].
    NOTE(review): the biased exponent is computed with 63-bit integer
    arithmetic, so presumably a very large |e| would wrap around.  Confirm
    that callers keep [e] in range; [pow2_UP] clamps it. *)
Definition scale x e :=
  ldshiftexp x (Int63.of_Z e + Int63.of_Z FloatOps.shift)%int63.

(** Upper bound on 2^e ([precision] is unused).
    - When e >= emax, the result is [infinity].
    - Otherwise it is 1 scaled by max e (-1074).  Exponents below -1074
      therefore yield 2^-1074, which over-approximates 2^e. *)
Definition pow2_UP (_ : precision) e :=
  if Zle_bool emax e then infinity else scale (fromZ 1) (Z.max e (-1074)).

(** Division by two, using the primitive division. *)
Definition div2 x := (x / 2)%float.

(** Directed arithmetic: compute the primitive operation, then move the
    result one float up ([next_up]) or down ([next_down]).  The result is
    therefore an upper or lower bound on the exact result, as stated by
    [add_UP_correct] and the related lemmas below.  The precision argument
    is ignored. *)
Definition add_UP (_ : precision) x y := next_up (x + y).

Definition add_DN (_ : precision) x y := next_down (x + y).

Definition sub_UP (_ : precision) x y := next_up (x - y).

Definition sub_DN (_ : precision) x y := next_down (x - y).

Definition mul_UP (_ : precision) x y := next_up (x × y).

Definition mul_DN (_ : precision) x y := next_down (x × y).

Definition div_UP (_ : precision) x y := next_up (x / y).

Definition div_DN (_ : precision) x y := next_down (x / y).

Definition sqrt_UP (_ : precision) x := next_up (PrimFloat.sqrt x).

Definition sqrt_DN (_ : precision) x := next_down (PrimFloat.sqrt x).

(** Round [f] to an integral float according to [mode].  Non-real inputs
    (infinities and NaN) return [default].

    [frshiftexp f] splits [f] into a fraction [f'] and a biased exponent [e].
    - If FloatOps.prec + FloatOps.shift <= e, |f| is already an integer and
      [f] is returned unchanged.
    - Otherwise [mh] is the integer part of |f|: the integer mantissa of
      [f'] shifted right by [d] bits.  Each mode then adjusts [mh]:
      - [rnd_ZR] truncates toward zero;
      - [rnd_DN] truncates, then moves down to mh + 1 in magnitude when [f]
        is negative and lies strictly below -mh;
      - [rnd_UP] truncates, then moves up to mh + 1 when [f] is positive and
        lies strictly above mh;
      - [rnd_NE] compares the discarded fraction |f| - mh with 0.5.  Ties go
        to the even candidate, as does the not-comparable case.  The sign of
        [f] is restored afterwards. *)
Definition nearbyint default (mode : rounding_mode) (f : type) :=
  if real f then
    let '(f', e) := frshiftexp f in
    if Int63.leb (of_Z (FloatOps.prec + FloatOps.shift))%int63 e then f else
      let m := normfr_mantissa f' in
      let d := (of_Z (FloatOps.prec + FloatOps.shift) - e)%int63 in
      let mh := (m >> d)%int63 in
      match mode with
      | rnd_ZRif get_sign f then (- (of_int63 mh))%float else of_int63 mh
      | rnd_DN
        if get_sign f then
          let f'' := (- (of_int63 mh))%float in
          if PrimFloat.ltb f f'' then (- (of_int63 (mh + 1)))%float else f''
        else
          of_int63 mh
      | rnd_UP
        if get_sign f then
          PrimFloat.opp (of_int63 mh)
        else
          let f'' := of_int63 mh in
          if PrimFloat.ltb f'' f then of_int63 (mh + 1) else f''
      | rnd_NE
        let fl := of_int63 mh in
        let f' :=
            match (abs f - fl ?= 0.5)%float with
            | FLtfl
            | FGtof_int63 (mh + 1)
            | FEq | FNotComparable
                if Int63.eqb (mh land 1) 0 then fl else of_int63 (mh + 1)
            end in
        if get_sign f then (- f')%float else f'
      end
  else default.

(** Directed variants of [nearbyint].  A non-real input yields [infinity]
    (resp. [neg_infinity]), which is still a valid upper (resp. lower)
    bound. *)
Definition nearbyint_UP := nearbyint infinity.

Definition nearbyint_DN := nearbyint neg_infinity.

(** Midpoint of [x] and [y], computed as (x + y) / 2.  If that evaluates to
    an infinity (for instance when x + y overflows), x / 2 + y / 2 is used
    instead. *)
Definition midpoint (x y : type) :=
  let z := ((x + y) / 2)%float in
  if is_infinity z then (x / 2 + y / 2)%float else z.

(** Interpretations of a float.
    - [toX] and [convert] give the extended real and share the same
      definition.
    - [toR] gives the real value, via [proj_val]. *)
Definition toX x := FtoX (toF x).
Definition toR x := proj_val (toX x).
Definition convert x := FtoX (toF x).

(** Specification lemmas required by [FloatOps] and auxiliary results.
    Only the statements are visible in this rendering. *)
Lemma ZtoS_correct:
   prec z,
  (z StoZ (ZtoS z))%Z toX (pow2_UP prec (ZtoS z)) = Xnan.

Lemma zero_correct : toX zero = Xreal 0.

Lemma nan_correct : classify nan = Sig.Fnan.

(** Interpretation of a Flocq [binary_float] as an extended real.
    - Zeros map to [Xreal 0].
    - Finite values map to their real value.
    - Everything else (infinities, NaN) maps to [Xnan]. *)
Definition BtoX (x : binary_float FloatOps.prec emax) :=
  match x with
  | B754_zero _Xreal 0
  | B754_finite s m e _Xreal (FtoR radix2 s m e)
  | _Xnan
  end.

(** Bridging lemmas between [BtoX], [B2R], [Prim2B] and [toX]. *)
Lemma BtoX_B2R x r : BtoX x = Xreal r r = B2R x.

Lemma B2R_BtoX : x, is_finite x = true BtoX x = Xreal (B2R x).

Lemma toX_Prim2B x : toX x = BtoX (Prim2B x).

Lemma BtoX_Bopp x : BtoX (Bopp x) = (- (BtoX x))%XR.

(** Classification and validity specifications. *)
Lemma valid_lb_correct :
   f, valid_lb f = match classify f with Fpinftyfalse | _true end.

Lemma valid_ub_correct :
   f, valid_ub f = match classify f with Fminftyfalse | _true end.

Lemma classify_correct :
   f, real f = match classify f with Frealtrue | _false end.

Lemma real_correct :
   f, real f = match toX f with Xnanfalse | _true end.

Lemma is_nan_correct :
   f, is_nan f = match classify f with Sig.Fnantrue | _false end.

Lemma real_is_finite x : real (B2Prim x) = is_finite x.

(** Make [Hprec] and [Hmax] available to instance resolution for the Flocq
    lemmas below (presumably the precision and emax side conditions). *)
Local Existing Instance Hprec.
Local Existing Instance Hmax.

(** Integer conversion is exact for integers bounded by 2^53. *)
Lemma of_int63_exact i :
  (Int63.to_Z i 2^53)%Z
  toX (of_int63 i) = Xreal (IZR (Int63.to_Z i)).

Lemma of_int63_of_pos_exact p :
  (p < 2^53)%positive
  toX (of_int63 (Int63.of_pos p)) = Xreal (IZR (Zpos p)).

Lemma toX_neg x : toX (- x) = (- (toX x))%XR.

(** [fromZ] is exact on integers whose magnitude is bounded by 256. *)
Lemma fromZ_correct :
   n,
  (Z.abs n 256)%Z toX (fromZ n) = Xreal (IZR n).

Lemma mag_correct :
   f, (Rabs (toR f) < bpow radix2 (StoZ (mag f)))%R.

(** [next_up] never yields [neg_infinity], and [next_down] never yields
    [infinity]. *)
Lemma valid_ub_next_up x : valid_ub (next_up x) = true.

Lemma valid_lb_next_down x : valid_lb (next_down x) = true.

(** The 53-bit truncation used by [fromZ_UP] and [fromZ_DN] brackets its
    input. *)
Lemma shiftr_pos p :
  let d := Z.log2 (Z.pos p) in
  let s := Z.shiftr (Z.pos p) (d - 52) in
  (0 d - 52
   (s × 2 ^ (d - 52) Z.pos p < (s + 1) × 2 ^ (d - 52)
     s < 2^53))%Z.

Lemma Bsign_pos x r : BtoX x = Xreal r (0 < r)%R Bsign x = false.

Lemma fromZ_UP_correct :
   p n,
  valid_ub (fromZ_UP p n) = true le_upper (Xreal (IZR n)) (toX (fromZ_UP p n)).

Lemma fromZ_DN_correct :
   p n,
  valid_lb (fromZ_DN p n) = true le_lower (toX (fromZ_DN p n)) (Xreal (IZR n)).

(** Comparison specification, by classification of both arguments. *)
Lemma cmp_correct :
   x y,
  cmp x y =
  match classify x, classify y with
  | Sig.Fnan, _ | _, Sig.FnanXund
  | Fminfty, FminftyXeq
  | Fminfty, _Xlt
  | _, FminftyXgt
  | Fpinfty, FpinftyXeq
  | _, FpinftyXlt
  | Fpinfty, _Xgt
  | Freal, FrealXcmp (toX x) (toX y)
  end.

(** Inverse of [Xcomparison_of_float_comparison]; [Xund] maps back to
    [FNotComparable]. *)
Definition float_comparison_of_Xcomparison c :=
  match c with
  | XeqFEq
  | XltFLt
  | XgtFGt
  | XundFNotComparable
  end.

Lemma compare_cmp x y : compare x y = float_comparison_of_Xcomparison (cmp x y).

(** Specifications of [min], [max], [neg] and [abs]. *)
Lemma min_correct :
   x y,
  match classify x, classify y with
  | Sig.Fnan, _ | _, Sig.Fnanclassify (min x y) = Sig.Fnan
  | Fminfty, _ | _, Fminftyclassify (min x y) = Fminfty
  | Fpinfty, _min x y = y
  | _, Fpinftymin x y = x
  | Freal, FrealtoX (min x y) = Xmin (toX x) (toX y)
  end.

Lemma Rmax_compare x y :
  Rmax x y = match Rcompare x y with Lty | _x end.

Lemma max_correct :
   x y,
  match classify x, classify y with
  | Sig.Fnan, _ | _, Sig.Fnanclassify (max x y) = Sig.Fnan
  | Fpinfty, _ | _, Fpinftyclassify (max x y) = Fpinfty
  | Fminfty, _max x y = y
  | _, Fminftymax x y = x
  | Freal, FrealtoX (max x y) = Xmax (toX x) (toX y)
  end.

Lemma neg_correct :
   x,
  match classify x with
  | FrealtoX (neg x) = Xneg (toX x)
  | Sig.Fnanclassify (neg x) = Sig.Fnan
  | Fminftyclassify (neg x) = Fpinfty
  | Fpinftyclassify (neg x) = Fminfty
  end.

Lemma abs_correct :
   x, toX (abs x) = Xabs (toX x) (valid_ub (abs x) = true).

(** Same instances, taken explicitly from [PrimFloat], for the lemmas
    below. *)
Local Existing Instance PrimFloat.Hprec.
Local Existing Instance PrimFloat.Hmax.

(** Division by two on Flocq binary floats: it rounds [B2R x / 2], stays
    finite, keeps the sign and does not increase the magnitude. *)
Lemma Bdiv2_correct x :
  is_finite x = true
  let x2 := Bdiv mode_NE x (Prim2B 2) in
  B2R x2 =
    Generic_fmt.round radix2
      (FLT.FLT_exp (3 - emax - FloatOps.prec) FloatOps.prec)
      (round_mode mode_NE)
      (B2R x / 2)
   is_finite x2 = true
   Bsign x2 = Bsign x
   (Rabs (B2R x2) Rabs (B2R x))%R.

Lemma div2_correct :
   x, sensible_format = true
  (1 / 256 Rabs (toR x))%R
  toX (div2 x) = Xdiv (toX x) (Xreal 2).

(** One step of [Bsucc] (resp. [Bpred]) on a finite float gives an upper
    (resp. lower) bound.  These are the key facts behind the directed
    operations. *)
Lemma le_upper_succ_finite s m e B :
  le_upper (@FtoX radix2 (Basic.Float s m e))
    (@FtoX radix2
       match B2SF (Bsucc (B754_finite s m e B)) with
       | S754_zero _Fzero
       | S754_finite s m eBasic.Float s m e
       | _Basic.Fnan
       end).

Lemma add_UP_correct :
   p x y, valid_ub x = true valid_ub y = true
     (valid_ub (add_UP p x y) = true
        le_upper (Xadd (toX x) (toX y)) (toX (add_UP p x y))).

Lemma le_lower_pred_finite s m e B :
  le_lower
    (@FtoX radix2
       match B2SF (Bpred (B754_finite s m e B)) with
       | S754_zero _Fzero
       | S754_finite s m eBasic.Float s m e
       | _Basic.Fnan
       end)
    (@FtoX radix2 (Basic.Float s m e)).

Lemma add_DN_correct :
   p x y, valid_lb x = true valid_lb y = true
     (valid_lb (add_DN p x y) = true
        le_lower (toX (add_DN p x y)) (Xadd (toX x) (toX y))).

Lemma sub_UP_correct :
   p x y, valid_ub x = true valid_lb y = true
     (valid_ub (sub_UP p x y) = true
        le_upper (Xsub (toX x) (toX y)) (toX (sub_UP p x y))).

Lemma sub_DN_correct :
   p x y, valid_lb x = true valid_ub y = true
     (valid_lb (sub_DN p x y) = true
        le_lower (toX (sub_DN p x y)) (Xsub (toX x) (toX y))).

(** Sign predicates used in the hypotheses of the multiplication
    specifications.
    - [is_non_neg], [is_pos], [is_non_pos] and [is_neg] pair a validity
      condition ([valid_ub] or [valid_lb]) with a sign condition on the real
      value.  The connective between the two lines is missing from this
      rendering.
    - The primed variants require validity only when the value is [Xnan].
    - The [_real] variants require a real value. *)
Definition is_non_neg x :=
  valid_ub x = true
   match toX x with XnanTrue | Xreal r ⇒ (0 r)%R end.

Definition is_non_neg' x :=
  match toX x with Xnanvalid_ub x = true | Xreal r ⇒ (0 r)%R end.

Definition is_pos x :=
  valid_ub x = true
   match toX x with XnanTrue | Xreal r ⇒ (0 < r)%R end.

Definition is_non_pos x :=
  valid_lb x = true
   match toX x with XnanTrue | Xreal r ⇒ (r 0)%R end.

Definition is_non_pos' x :=
  match toX x with Xnanvalid_lb x = true | Xreal r ⇒ (r 0)%R end.

Definition is_neg x :=
  valid_lb x = true
   match toX x with XnanTrue | Xreal r ⇒ (r < 0)%R end.

Definition is_non_neg_real x :=
  match toX x with XnanFalse | Xreal r ⇒ (0 r)%R end.

Definition is_pos_real x :=
  match toX x with XnanFalse | Xreal r ⇒ (0 < r)%R end.

Definition is_non_pos_real x :=
  match toX x with XnanFalse | Xreal r ⇒ (r 0)%R end.

Definition is_neg_real x :=
  match toX x with XnanFalse | Xreal r ⇒ (r < 0)%R end.

Lemma mul_UP_correct :
   p x y,
  ((is_non_neg' x is_non_neg' y)
   (is_non_pos' x is_non_pos' y)
   (is_non_pos_real x is_non_neg_real y)
   (is_non_neg_real x is_non_pos_real y))
  valid_ub (mul_UP p x y) = true
  le_upper (Xmul (toX x) (toX y)) (toX (mul_UP p x y)).

Lemma mul_DN_correct :
   p x y,
  ((is_non_neg_real x is_non_neg_real y)
   (is_non_pos_real x is_non_pos_real y)
   (is_non_neg' x is_non_pos' y)
   (is_non_pos' x is_non_neg' y))
  (valid_lb (mul_DN p x y) = true
  le_lower (toX (mul_DN p x y)) (Xmul (toX x) (toX y))).

Lemma pow2_UP_correct :
   p s, (valid_ub (pow2_UP p s) = true
              le_upper (Xscale radix2 (Xreal 1) (StoZ s)) (toX (pow2_UP p s))).

(** Bound-acceptability predicates used in the division specifications.
    The value must be real; otherwise the float must be a valid upper
    (resp. lower) bound. *)
Definition is_real_ub x :=
  match toX x with Xnanvalid_ub x = true | _True end.

Definition is_real_lb x :=
  match toX x with Xnanvalid_lb x = true | _True end.

Lemma div_UP_correct :
   p x y,
  ((is_real_ub x is_pos_real y)
   (is_real_lb x is_neg_real y))
  valid_ub (div_UP p x y) = true
  le_upper (Xdiv (toX x) (toX y)) (toX (div_UP p x y)).

Lemma div_DN_correct :
   p x y,
  ((is_real_ub x is_neg_real y)
   (is_real_lb x is_pos_real y))
  valid_lb (div_DN p x y) = true
  le_lower (toX (div_DN p x y)) (Xdiv (toX x) (toX y)).

Lemma sqrt_UP_correct :
   p x,
  valid_ub (sqrt_UP p x) = true
   le_upper (Xsqrt (toX x)) (toX (sqrt_UP p x)).

Lemma sqrt_DN_correct :
   p x,
    valid_lb x = true
     (valid_lb (sqrt_DN p x) = true
         le_lower (toX (sqrt_DN p x)) (Xsqrt (toX x))).

(** For a float whose magnitude lies in [1/2, 1), the normalized mantissa
    is its integer significand: a full-precision mantissa with exponent
    -prec.  This underlies the [mh] computation in [nearbyint]. *)
Lemma Bnormfr_mantissa_correct :
   f : binary_float FloatOps.prec emax,
  (/ 2 Rabs (B2R f) < 1)%R
  match f with
  | B754_finite _ m e _
    Bnormfr_mantissa f = N.pos m
     Z.pos (digits2_pos m) = FloatOps.prec (e = - FloatOps.prec)%Z
  | _False
  end.

(** Specifications of [nearbyint] and its directed variants. *)
Lemma nearbyint_correct :
   default mode x,
  real x = true
  Xnearbyint mode (toX x) = toX (nearbyint default mode x).

Lemma nearbyint_UP_correct :
   mode x,
  valid_ub (nearbyint_UP mode x) = true
   le_upper (Xnearbyint mode (toX x)) (toX (nearbyint_UP mode x)).

Lemma nearbyint_DN_correct :
   mode x,
  valid_lb (nearbyint_DN mode x) = true
   le_lower (toX (nearbyint_DN mode x)) (Xnearbyint mode (toX x)).

(** For real x <= y, the midpoint is real and lies between [x] and [y]. *)
Lemma midpoint_correct :
   x y,
  sensible_format = true
  real x = true real y = true (toR x toR y)%R
   real (midpoint x y) = true (toR x toR (midpoint x y) toR y)%R.

End PrimitiveFloat.