123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- // Package math exposes functionality from Go's math package as an elvish
- // module.
- package math
- import (
- "math"
- "math/big"
- "src.elv.sh/pkg/eval"
- "src.elv.sh/pkg/eval/errs"
- "src.elv.sh/pkg/eval/vals"
- "src.elv.sh/pkg/eval/vars"
- )
- // Ns is the namespace for the math: module.
- var Ns = eval.BuildNsNamed("math").
- AddVars(map[string]vars.Var{
- "e": vars.NewReadOnly(math.E),
- "pi": vars.NewReadOnly(math.Pi),
- }).
- AddGoFns(map[string]any{
- "abs": abs,
- "acos": math.Acos,
- "acosh": math.Acosh,
- "asin": math.Asin,
- "asinh": math.Asinh,
- "atan": math.Atan,
- "atanh": math.Atanh,
- "ceil": ceil,
- "cos": math.Cos,
- "cosh": math.Cosh,
- "floor": floor,
- "is-inf": isInf,
- "is-nan": isNaN,
- "log": math.Log,
- "log10": math.Log10,
- "log2": math.Log2,
- "max": max,
- "min": min,
- "pow": pow,
- "round": round,
- "round-to-even": roundToEven,
- "sin": math.Sin,
- "sinh": math.Sinh,
- "sqrt": math.Sqrt,
- "tan": math.Tan,
- "tanh": math.Tanh,
- "trunc": trunc,
- }).Ns()
- const (
- maxInt = int(^uint(0) >> 1)
- minInt = -maxInt - 1
- )
- var absMinInt = new(big.Int).Abs(big.NewInt(int64(minInt)))
- func abs(n vals.Num) vals.Num {
- switch n := n.(type) {
- case int:
- if n < 0 {
- if n == minInt {
- return absMinInt
- }
- return -n
- }
- return n
- case *big.Int:
- if n.Sign() < 0 {
- return new(big.Int).Abs(n)
- }
- return n
- case *big.Rat:
- if n.Sign() < 0 {
- return new(big.Rat).Abs(n)
- }
- return n
- case float64:
- return math.Abs(n)
- default:
- panic("unreachable")
- }
- }
- var (
- big1 = big.NewInt(1)
- big2 = big.NewInt(2)
- )
- func ceil(n vals.Num) vals.Num {
- return integerize(n,
- math.Ceil,
- func(n *big.Rat) *big.Int {
- q := new(big.Int).Div(n.Num(), n.Denom())
- return q.Add(q, big1)
- })
- }
- func floor(n vals.Num) vals.Num {
- return integerize(n,
- math.Floor,
- func(n *big.Rat) *big.Int {
- return new(big.Int).Div(n.Num(), n.Denom())
- })
- }
- type isInfOpts struct{ Sign int }
- func (opts *isInfOpts) SetDefaultOptions() { opts.Sign = 0 }
- func isInf(opts isInfOpts, n vals.Num) bool {
- if f, ok := n.(float64); ok {
- return math.IsInf(f, opts.Sign)
- }
- return false
- }
- func isNaN(n vals.Num) bool {
- if f, ok := n.(float64); ok {
- return math.IsNaN(f)
- }
- return false
- }
- func max(rawNums ...vals.Num) (vals.Num, error) {
- if len(rawNums) == 0 {
- return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0}
- }
- nums := vals.UnifyNums(rawNums, 0)
- switch nums := nums.(type) {
- case []int:
- n := nums[0]
- for i := 1; i < len(nums); i++ {
- if n < nums[i] {
- n = nums[i]
- }
- }
- return n, nil
- case []*big.Int:
- n := nums[0]
- for i := 1; i < len(nums); i++ {
- if n.Cmp(nums[i]) < 0 {
- n = nums[i]
- }
- }
- return n, nil
- case []*big.Rat:
- n := nums[0]
- for i := 1; i < len(nums); i++ {
- if n.Cmp(nums[i]) < 0 {
- n = nums[i]
- }
- }
- return n, nil
- case []float64:
- n := nums[0]
- for i := 1; i < len(nums); i++ {
- n = math.Max(n, nums[i])
- }
- return n, nil
- default:
- panic("unreachable")
- }
- }
- func min(rawNums ...vals.Num) (vals.Num, error) {
- if len(rawNums) == 0 {
- return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0}
- }
- nums := vals.UnifyNums(rawNums, 0)
- switch nums := nums.(type) {
- case []int:
- n := nums[0]
- for i := 1; i < len(nums); i++ {
- if n > nums[i] {
- n = nums[i]
- }
- }
- return n, nil
- case []*big.Int:
- n := nums[0]
- for i := 1; i < len(nums); i++ {
- if n.Cmp(nums[i]) > 0 {
- n = nums[i]
- }
- }
- return n, nil
- case []*big.Rat:
- n := nums[0]
- for i := 1; i < len(nums); i++ {
- if n.Cmp(nums[i]) > 0 {
- n = nums[i]
- }
- }
- return n, nil
- case []float64:
- n := nums[0]
- for i := 1; i < len(nums); i++ {
- n = math.Min(n, nums[i])
- }
- return n, nil
- default:
- panic("unreachable")
- }
- }
- func pow(base, exp vals.Num) vals.Num {
- if isExact(base) && isExactInt(exp) {
- // Produce exact result
- switch exp {
- case 0:
- return 1
- case 1:
- return base
- case -1:
- return new(big.Rat).Inv(vals.PromoteToBigRat(base))
- }
- exp := vals.PromoteToBigInt(exp)
- if isExactInt(base) && exp.Sign() > 0 {
- base := vals.PromoteToBigInt(base)
- return new(big.Int).Exp(base, exp, nil)
- }
- base := vals.PromoteToBigRat(base)
- if exp.Sign() < 0 {
- base = new(big.Rat).Inv(base)
- exp = new(big.Int).Neg(exp)
- }
- return new(big.Rat).SetFrac(
- new(big.Int).Exp(base.Num(), exp, nil),
- new(big.Int).Exp(base.Denom(), exp, nil))
- }
- // Produce inexact result
- basef := vals.ConvertToFloat64(base)
- expf := vals.ConvertToFloat64(exp)
- return math.Pow(basef, expf)
- }
- func isExact(n vals.Num) bool {
- switch n.(type) {
- case int, *big.Int, *big.Rat:
- return true
- default:
- return false
- }
- }
- func isExactInt(n vals.Num) bool {
- switch n.(type) {
- case int, *big.Int:
- return true
- default:
- return false
- }
- }
- func round(n vals.Num) vals.Num {
- return integerize(n,
- math.Round,
- func(n *big.Rat) *big.Int {
- q, m := new(big.Int).QuoRem(n.Num(), n.Denom(), new(big.Int))
- m = m.Mul(m, big2)
- if m.CmpAbs(n.Denom()) < 0 {
- return q
- }
- if n.Sign() < 0 {
- return q.Sub(q, big1)
- }
- return q.Add(q, big1)
- })
- }
- func roundToEven(n vals.Num) vals.Num {
- return integerize(n,
- math.RoundToEven,
- func(n *big.Rat) *big.Int {
- q, m := new(big.Int).QuoRem(n.Num(), n.Denom(), new(big.Int))
- m = m.Mul(m, big2)
- if diff := m.CmpAbs(n.Denom()); diff < 0 || diff == 0 && q.Bit(0) == 0 {
- return q
- }
- if n.Sign() < 0 {
- return q.Sub(q, big1)
- }
- return q.Add(q, big1)
- })
- }
- func trunc(n vals.Num) vals.Num {
- return integerize(n,
- math.Trunc,
- func(n *big.Rat) *big.Int {
- return new(big.Int).Quo(n.Num(), n.Denom())
- })
- }
- func integerize(n vals.Num, fnFloat func(float64) float64, fnRat func(*big.Rat) *big.Int) vals.Num {
- switch n := n.(type) {
- case int:
- return n
- case *big.Int:
- return n
- case *big.Rat:
- if n.Denom().IsInt64() && n.Denom().Int64() == 1 {
- // Elvish always normalizes *big.Rat with a denominator of 1 to
- // *big.Int, but we still try to be defensive here.
- return n.Num()
- }
- return fnRat(n)
- case float64:
- return fnFloat(n)
- default:
- panic("unreachable")
- }
- }
|