math.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // Package math exposes functionality from Go's math package as an elvish
  2. // module.
  3. package math
  4. import (
  5. "math"
  6. "math/big"
  7. "src.elv.sh/pkg/eval"
  8. "src.elv.sh/pkg/eval/errs"
  9. "src.elv.sh/pkg/eval/vals"
  10. "src.elv.sh/pkg/eval/vars"
  11. )
  12. // Ns is the namespace for the math: module.
  13. var Ns = eval.BuildNsNamed("math").
  14. AddVars(map[string]vars.Var{
  15. "e": vars.NewReadOnly(math.E),
  16. "pi": vars.NewReadOnly(math.Pi),
  17. }).
  18. AddGoFns(map[string]any{
  19. "abs": abs,
  20. "acos": math.Acos,
  21. "acosh": math.Acosh,
  22. "asin": math.Asin,
  23. "asinh": math.Asinh,
  24. "atan": math.Atan,
  25. "atanh": math.Atanh,
  26. "ceil": ceil,
  27. "cos": math.Cos,
  28. "cosh": math.Cosh,
  29. "floor": floor,
  30. "is-inf": isInf,
  31. "is-nan": isNaN,
  32. "log": math.Log,
  33. "log10": math.Log10,
  34. "log2": math.Log2,
  35. "max": max,
  36. "min": min,
  37. "pow": pow,
  38. "round": round,
  39. "round-to-even": roundToEven,
  40. "sin": math.Sin,
  41. "sinh": math.Sinh,
  42. "sqrt": math.Sqrt,
  43. "tan": math.Tan,
  44. "tanh": math.Tanh,
  45. "trunc": trunc,
  46. }).Ns()
  // maxInt and minInt are the bounds of the platform's int type. minInt is the
// only int whose negation overflows, so abs has to special-case it.
const (
	maxInt = int(math.MaxInt)
	minInt = int(math.MinInt)
)

// absMinInt is |minInt|, i.e. maxInt+1, which does not fit in an int and is
// therefore kept as a *big.Int. It is shared, so it must never be mutated.
var absMinInt = new(big.Int).Neg(big.NewInt(int64(minInt)))
  52. func abs(n vals.Num) vals.Num {
  53. switch n := n.(type) {
  54. case int:
  55. if n < 0 {
  56. if n == minInt {
  57. return absMinInt
  58. }
  59. return -n
  60. }
  61. return n
  62. case *big.Int:
  63. if n.Sign() < 0 {
  64. return new(big.Int).Abs(n)
  65. }
  66. return n
  67. case *big.Rat:
  68. if n.Sign() < 0 {
  69. return new(big.Rat).Abs(n)
  70. }
  71. return n
  72. case float64:
  73. return math.Abs(n)
  74. default:
  75. panic("unreachable")
  76. }
  77. }
  // big1 and big2 are shared *big.Int constants for the rounding helpers.
// They are only ever passed as operands, never used as a receiver, so
// they are never mutated.
var big1, big2 = big.NewInt(1), big.NewInt(2)
  82. func ceil(n vals.Num) vals.Num {
  83. return integerize(n,
  84. math.Ceil,
  85. func(n *big.Rat) *big.Int {
  86. q := new(big.Int).Div(n.Num(), n.Denom())
  87. return q.Add(q, big1)
  88. })
  89. }
  90. func floor(n vals.Num) vals.Num {
  91. return integerize(n,
  92. math.Floor,
  93. func(n *big.Rat) *big.Int {
  94. return new(big.Int).Div(n.Num(), n.Denom())
  95. })
  96. }
  97. type isInfOpts struct{ Sign int }
  98. func (opts *isInfOpts) SetDefaultOptions() { opts.Sign = 0 }
  99. func isInf(opts isInfOpts, n vals.Num) bool {
  100. if f, ok := n.(float64); ok {
  101. return math.IsInf(f, opts.Sign)
  102. }
  103. return false
  104. }
  105. func isNaN(n vals.Num) bool {
  106. if f, ok := n.(float64); ok {
  107. return math.IsNaN(f)
  108. }
  109. return false
  110. }
  111. func max(rawNums ...vals.Num) (vals.Num, error) {
  112. if len(rawNums) == 0 {
  113. return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0}
  114. }
  115. nums := vals.UnifyNums(rawNums, 0)
  116. switch nums := nums.(type) {
  117. case []int:
  118. n := nums[0]
  119. for i := 1; i < len(nums); i++ {
  120. if n < nums[i] {
  121. n = nums[i]
  122. }
  123. }
  124. return n, nil
  125. case []*big.Int:
  126. n := nums[0]
  127. for i := 1; i < len(nums); i++ {
  128. if n.Cmp(nums[i]) < 0 {
  129. n = nums[i]
  130. }
  131. }
  132. return n, nil
  133. case []*big.Rat:
  134. n := nums[0]
  135. for i := 1; i < len(nums); i++ {
  136. if n.Cmp(nums[i]) < 0 {
  137. n = nums[i]
  138. }
  139. }
  140. return n, nil
  141. case []float64:
  142. n := nums[0]
  143. for i := 1; i < len(nums); i++ {
  144. n = math.Max(n, nums[i])
  145. }
  146. return n, nil
  147. default:
  148. panic("unreachable")
  149. }
  150. }
  151. func min(rawNums ...vals.Num) (vals.Num, error) {
  152. if len(rawNums) == 0 {
  153. return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0}
  154. }
  155. nums := vals.UnifyNums(rawNums, 0)
  156. switch nums := nums.(type) {
  157. case []int:
  158. n := nums[0]
  159. for i := 1; i < len(nums); i++ {
  160. if n > nums[i] {
  161. n = nums[i]
  162. }
  163. }
  164. return n, nil
  165. case []*big.Int:
  166. n := nums[0]
  167. for i := 1; i < len(nums); i++ {
  168. if n.Cmp(nums[i]) > 0 {
  169. n = nums[i]
  170. }
  171. }
  172. return n, nil
  173. case []*big.Rat:
  174. n := nums[0]
  175. for i := 1; i < len(nums); i++ {
  176. if n.Cmp(nums[i]) > 0 {
  177. n = nums[i]
  178. }
  179. }
  180. return n, nil
  181. case []float64:
  182. n := nums[0]
  183. for i := 1; i < len(nums); i++ {
  184. n = math.Min(n, nums[i])
  185. }
  186. return n, nil
  187. default:
  188. panic("unreachable")
  189. }
  190. }
  191. func pow(base, exp vals.Num) vals.Num {
  192. if isExact(base) && isExactInt(exp) {
  193. // Produce exact result
  194. switch exp {
  195. case 0:
  196. return 1
  197. case 1:
  198. return base
  199. case -1:
  200. return new(big.Rat).Inv(vals.PromoteToBigRat(base))
  201. }
  202. exp := vals.PromoteToBigInt(exp)
  203. if isExactInt(base) && exp.Sign() > 0 {
  204. base := vals.PromoteToBigInt(base)
  205. return new(big.Int).Exp(base, exp, nil)
  206. }
  207. base := vals.PromoteToBigRat(base)
  208. if exp.Sign() < 0 {
  209. base = new(big.Rat).Inv(base)
  210. exp = new(big.Int).Neg(exp)
  211. }
  212. return new(big.Rat).SetFrac(
  213. new(big.Int).Exp(base.Num(), exp, nil),
  214. new(big.Int).Exp(base.Denom(), exp, nil))
  215. }
  216. // Produce inexact result
  217. basef := vals.ConvertToFloat64(base)
  218. expf := vals.ConvertToFloat64(exp)
  219. return math.Pow(basef, expf)
  220. }
  221. func isExact(n vals.Num) bool {
  222. switch n.(type) {
  223. case int, *big.Int, *big.Rat:
  224. return true
  225. default:
  226. return false
  227. }
  228. }
  229. func isExactInt(n vals.Num) bool {
  230. switch n.(type) {
  231. case int, *big.Int:
  232. return true
  233. default:
  234. return false
  235. }
  236. }
  237. func round(n vals.Num) vals.Num {
  238. return integerize(n,
  239. math.Round,
  240. func(n *big.Rat) *big.Int {
  241. q, m := new(big.Int).QuoRem(n.Num(), n.Denom(), new(big.Int))
  242. m = m.Mul(m, big2)
  243. if m.CmpAbs(n.Denom()) < 0 {
  244. return q
  245. }
  246. if n.Sign() < 0 {
  247. return q.Sub(q, big1)
  248. }
  249. return q.Add(q, big1)
  250. })
  251. }
  252. func roundToEven(n vals.Num) vals.Num {
  253. return integerize(n,
  254. math.RoundToEven,
  255. func(n *big.Rat) *big.Int {
  256. q, m := new(big.Int).QuoRem(n.Num(), n.Denom(), new(big.Int))
  257. m = m.Mul(m, big2)
  258. if diff := m.CmpAbs(n.Denom()); diff < 0 || diff == 0 && q.Bit(0) == 0 {
  259. return q
  260. }
  261. if n.Sign() < 0 {
  262. return q.Sub(q, big1)
  263. }
  264. return q.Add(q, big1)
  265. })
  266. }
  267. func trunc(n vals.Num) vals.Num {
  268. return integerize(n,
  269. math.Trunc,
  270. func(n *big.Rat) *big.Int {
  271. return new(big.Int).Quo(n.Num(), n.Denom())
  272. })
  273. }
  274. func integerize(n vals.Num, fnFloat func(float64) float64, fnRat func(*big.Rat) *big.Int) vals.Num {
  275. switch n := n.(type) {
  276. case int:
  277. return n
  278. case *big.Int:
  279. return n
  280. case *big.Rat:
  281. if n.Denom().IsInt64() && n.Denom().Int64() == 1 {
  282. // Elvish always normalizes *big.Rat with a denominator of 1 to
  283. // *big.Int, but we still try to be defensive here.
  284. return n.Num()
  285. }
  286. return fnRat(n)
  287. case float64:
  288. return fnFloat(n)
  289. default:
  290. panic("unreachable")
  291. }
  292. }