builtin_fn_num.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. package eval
  2. import (
  3. "fmt"
  4. "math"
  5. "math/big"
  6. "math/rand"
  7. "strconv"
  8. "time"
  9. "src.elv.sh/pkg/eval/errs"
  10. "src.elv.sh/pkg/eval/vals"
  11. )
  12. // Numerical operations.
  13. func init() {
  14. addBuiltinFns(map[string]any{
  15. // Constructor
  16. "float64": toFloat64,
  17. "num": num,
  18. "exact-num": exactNum,
  19. "inexact-num": inexactNum,
  20. // Comparison
  21. "<": lt,
  22. "<=": le,
  23. "==": eqNum,
  24. "!=": ne,
  25. ">": gt,
  26. ">=": ge,
  27. // Arithmetic
  28. "+": add,
  29. "-": sub,
  30. "*": mul,
  31. // Also handles cd /
  32. "/": slash,
  33. "%": rem,
  34. // Random
  35. "rand": rand.Float64,
  36. "randint": randint,
  37. "-randseed": randseed,
  38. "range": rangeFn,
  39. })
  40. // For rand and randint.
  41. rand.Seed(time.Now().UTC().UnixNano())
  42. }
  43. func num(n vals.Num) vals.Num {
  44. // Conversion is actually handled in vals/conversion.go.
  45. return n
  46. }
  47. func exactNum(n vals.Num) (vals.Num, error) {
  48. if f, ok := n.(float64); ok {
  49. r := new(big.Rat).SetFloat64(f)
  50. if r == nil {
  51. return nil, errs.BadValue{What: "argument here",
  52. Valid: "finite float", Actual: vals.ToString(f)}
  53. }
  54. return r, nil
  55. }
  56. return n, nil
  57. }
  58. func inexactNum(f float64) float64 {
  59. return f
  60. }
  61. func toFloat64(f float64) float64 {
  62. return f
  63. }
  64. func lt(nums ...vals.Num) bool {
  65. return chainCompare(nums,
  66. func(a, b int) bool { return a < b },
  67. func(a, b *big.Int) bool { return a.Cmp(b) < 0 },
  68. func(a, b *big.Rat) bool { return a.Cmp(b) < 0 },
  69. func(a, b float64) bool { return a < b })
  70. }
  71. func le(nums ...vals.Num) bool {
  72. return chainCompare(nums,
  73. func(a, b int) bool { return a <= b },
  74. func(a, b *big.Int) bool { return a.Cmp(b) <= 0 },
  75. func(a, b *big.Rat) bool { return a.Cmp(b) <= 0 },
  76. func(a, b float64) bool { return a <= b })
  77. }
  78. func eqNum(nums ...vals.Num) bool {
  79. return chainCompare(nums,
  80. func(a, b int) bool { return a == b },
  81. func(a, b *big.Int) bool { return a.Cmp(b) == 0 },
  82. func(a, b *big.Rat) bool { return a.Cmp(b) == 0 },
  83. func(a, b float64) bool { return a == b })
  84. }
  85. func ne(nums ...vals.Num) bool {
  86. return chainCompare(nums,
  87. func(a, b int) bool { return a != b },
  88. func(a, b *big.Int) bool { return a.Cmp(b) != 0 },
  89. func(a, b *big.Rat) bool { return a.Cmp(b) != 0 },
  90. func(a, b float64) bool { return a != b })
  91. }
  92. func gt(nums ...vals.Num) bool {
  93. return chainCompare(nums,
  94. func(a, b int) bool { return a > b },
  95. func(a, b *big.Int) bool { return a.Cmp(b) > 0 },
  96. func(a, b *big.Rat) bool { return a.Cmp(b) > 0 },
  97. func(a, b float64) bool { return a > b })
  98. }
  99. func ge(nums ...vals.Num) bool {
  100. return chainCompare(nums,
  101. func(a, b int) bool { return a >= b },
  102. func(a, b *big.Int) bool { return a.Cmp(b) >= 0 },
  103. func(a, b *big.Rat) bool { return a.Cmp(b) >= 0 },
  104. func(a, b float64) bool { return a >= b })
  105. }
  106. func chainCompare(nums []vals.Num,
  107. p1 func(a, b int) bool, p2 func(a, b *big.Int) bool,
  108. p3 func(a, b *big.Rat) bool, p4 func(a, b float64) bool) bool {
  109. for i := 0; i < len(nums)-1; i++ {
  110. var r bool
  111. a, b := vals.UnifyNums2(nums[i], nums[i+1], 0)
  112. switch a := a.(type) {
  113. case int:
  114. r = p1(a, b.(int))
  115. case *big.Int:
  116. r = p2(a, b.(*big.Int))
  117. case *big.Rat:
  118. r = p3(a, b.(*big.Rat))
  119. case float64:
  120. r = p4(a, b.(float64))
  121. }
  122. if !r {
  123. return false
  124. }
  125. }
  126. return true
  127. }
  128. func add(rawNums ...vals.Num) vals.Num {
  129. nums := vals.UnifyNums(rawNums, vals.BigInt)
  130. switch nums := nums.(type) {
  131. case []*big.Int:
  132. acc := big.NewInt(0)
  133. for _, num := range nums {
  134. acc.Add(acc, num)
  135. }
  136. return vals.NormalizeBigInt(acc)
  137. case []*big.Rat:
  138. acc := big.NewRat(0, 1)
  139. for _, num := range nums {
  140. acc.Add(acc, num)
  141. }
  142. return vals.NormalizeBigRat(acc)
  143. case []float64:
  144. acc := float64(0)
  145. for _, num := range nums {
  146. acc += num
  147. }
  148. return acc
  149. default:
  150. panic("unreachable")
  151. }
  152. }
  153. func sub(rawNums ...vals.Num) (vals.Num, error) {
  154. if len(rawNums) == 0 {
  155. return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0}
  156. }
  157. nums := vals.UnifyNums(rawNums, vals.BigInt)
  158. switch nums := nums.(type) {
  159. case []*big.Int:
  160. acc := &big.Int{}
  161. if len(nums) == 1 {
  162. acc.Neg(nums[0])
  163. return acc, nil
  164. }
  165. acc.Set(nums[0])
  166. for _, num := range nums[1:] {
  167. acc.Sub(acc, num)
  168. }
  169. return acc, nil
  170. case []*big.Rat:
  171. acc := &big.Rat{}
  172. if len(nums) == 1 {
  173. acc.Neg(nums[0])
  174. return acc, nil
  175. }
  176. acc.Set(nums[0])
  177. for _, num := range nums[1:] {
  178. acc.Sub(acc, num)
  179. }
  180. return acc, nil
  181. case []float64:
  182. if len(nums) == 1 {
  183. return -nums[0], nil
  184. }
  185. acc := nums[0]
  186. for _, num := range nums[1:] {
  187. acc -= num
  188. }
  189. return acc, nil
  190. default:
  191. panic("unreachable")
  192. }
  193. }
  194. func mul(rawNums ...vals.Num) vals.Num {
  195. hasExact0 := false
  196. hasInf := false
  197. for _, num := range rawNums {
  198. if num == 0 {
  199. hasExact0 = true
  200. }
  201. if f, ok := num.(float64); ok && math.IsInf(f, 0) {
  202. hasInf = true
  203. break
  204. }
  205. }
  206. if hasExact0 && !hasInf {
  207. return 0
  208. }
  209. nums := vals.UnifyNums(rawNums, vals.BigInt)
  210. switch nums := nums.(type) {
  211. case []*big.Int:
  212. acc := big.NewInt(1)
  213. for _, num := range nums {
  214. acc.Mul(acc, num)
  215. }
  216. return vals.NormalizeBigInt(acc)
  217. case []*big.Rat:
  218. acc := big.NewRat(1, 1)
  219. for _, num := range nums {
  220. acc.Mul(acc, num)
  221. }
  222. return vals.NormalizeBigRat(acc)
  223. case []float64:
  224. acc := float64(1)
  225. for _, num := range nums {
  226. acc *= num
  227. }
  228. return acc
  229. default:
  230. panic("unreachable")
  231. }
  232. }
  233. func slash(fm *Frame, args ...vals.Num) error {
  234. if len(args) == 0 {
  235. // cd /
  236. return fm.Evaler.Chdir("/")
  237. }
  238. // Division
  239. result, err := div(args...)
  240. if err != nil {
  241. return err
  242. }
  243. return fm.ValueOutput().Put(vals.FromGo(result))
  244. }
  245. // ErrDivideByZero is thrown when attempting to divide by zero.
  246. var ErrDivideByZero = errs.BadValue{
  247. What: "divisor", Valid: "number other than exact 0", Actual: "exact 0"}
  248. func div(rawNums ...vals.Num) (vals.Num, error) {
  249. for _, num := range rawNums[1:] {
  250. if num == 0 {
  251. return nil, ErrDivideByZero
  252. }
  253. }
  254. if rawNums[0] == 0 {
  255. return 0, nil
  256. }
  257. nums := vals.UnifyNums(rawNums, vals.BigRat)
  258. switch nums := nums.(type) {
  259. case []*big.Rat:
  260. acc := &big.Rat{}
  261. acc.Set(nums[0])
  262. if len(nums) == 1 {
  263. acc.Inv(acc)
  264. return acc, nil
  265. }
  266. for _, num := range nums[1:] {
  267. acc.Quo(acc, num)
  268. }
  269. return acc, nil
  270. case []float64:
  271. acc := nums[0]
  272. if len(nums) == 1 {
  273. return 1 / acc, nil
  274. }
  275. for _, num := range nums[1:] {
  276. acc /= num
  277. }
  278. return acc, nil
  279. default:
  280. panic("unreachable")
  281. }
  282. }
  283. func rem(a, b int) (int, error) {
  284. // TODO: Support other number types
  285. if b == 0 {
  286. return 0, ErrDivideByZero
  287. }
  288. return a % b, nil
  289. }
  290. func randint(args ...int) (int, error) {
  291. var low, high int
  292. switch len(args) {
  293. case 1:
  294. low, high = 0, args[0]
  295. case 2:
  296. low, high = args[0], args[1]
  297. default:
  298. return -1, errs.ArityMismatch{What: "arguments",
  299. ValidLow: 1, ValidHigh: 2, Actual: len(args)}
  300. }
  301. if high <= low {
  302. return 0, errs.BadValue{What: "high value",
  303. Valid: fmt.Sprint("larger than ", low), Actual: strconv.Itoa(high)}
  304. }
  305. return low + rand.Intn(high-low), nil
  306. }
  307. func randseed(x int) { rand.Seed(int64(x)) }
  308. type rangeOpts struct{ Step vals.Num }
  309. // TODO: The default value can only be used implicitly; passing "range
  310. // &step=nil" results in an error.
  311. func (o *rangeOpts) SetDefaultOptions() { o.Step = nil }
  312. func rangeFn(fm *Frame, opts rangeOpts, args ...vals.Num) error {
  313. var rawNums []vals.Num
  314. switch len(args) {
  315. case 1:
  316. rawNums = []vals.Num{0, args[0]}
  317. case 2:
  318. rawNums = []vals.Num{args[0], args[1]}
  319. default:
  320. return errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: 2, Actual: len(args)}
  321. }
  322. if opts.Step != nil {
  323. rawNums = append(rawNums, opts.Step)
  324. }
  325. nums := vals.UnifyNums(rawNums, vals.Int)
  326. out := fm.ValueOutput()
  327. switch nums := nums.(type) {
  328. case []int:
  329. return rangeBuiltinNum(nums, out)
  330. case []*big.Int:
  331. return rangeBigNum(nums, out, bigIntDesc)
  332. case []*big.Rat:
  333. return rangeBigNum(nums, out, bigRatDesc)
  334. case []float64:
  335. return rangeBuiltinNum(nums, out)
  336. default:
  337. panic("unreachable")
  338. }
  339. }
  340. type builtinNum interface{ int | float64 }
  341. func rangeBuiltinNum[T builtinNum](nums []T, out ValueOutput) error {
  342. start, end := nums[0], nums[1]
  343. var step T
  344. if start <= end {
  345. if len(nums) == 3 {
  346. step = nums[2]
  347. if step <= 0 {
  348. return errs.BadValue{
  349. What: "step", Valid: "positive", Actual: vals.ToString(step)}
  350. }
  351. } else {
  352. step = 1
  353. }
  354. for cur := start; cur < end; cur += step {
  355. err := out.Put(vals.FromGo(cur))
  356. if err != nil {
  357. return err
  358. }
  359. if cur+step <= cur {
  360. break
  361. }
  362. }
  363. } else {
  364. if len(nums) == 3 {
  365. step = nums[2]
  366. if step >= 0 {
  367. return errs.BadValue{
  368. What: "step", Valid: "negative", Actual: vals.ToString(step)}
  369. }
  370. } else {
  371. step = -1
  372. }
  373. for cur := start; cur > end; cur += step {
  374. err := out.Put(vals.FromGo(cur))
  375. if err != nil {
  376. return err
  377. }
  378. if cur+step >= cur {
  379. break
  380. }
  381. }
  382. }
  383. return nil
  384. }
  385. type bigNum[T any] interface {
  386. Cmp(T) int
  387. Sign() int
  388. Add(T, T) T
  389. }
  390. type bigNumDesc[T any] struct {
  391. one T
  392. negOne T
  393. newZero func() T
  394. }
  395. var bigIntDesc = bigNumDesc[*big.Int]{
  396. one: big.NewInt(1),
  397. negOne: big.NewInt(-1),
  398. newZero: func() *big.Int { return &big.Int{} },
  399. }
  400. var bigRatDesc = bigNumDesc[*big.Rat]{
  401. one: big.NewRat(1, 1),
  402. negOne: big.NewRat(-1, 1),
  403. newZero: func() *big.Rat { return &big.Rat{} },
  404. }
  405. func rangeBigNum[T bigNum[T]](nums []T, out ValueOutput, d bigNumDesc[T]) error {
  406. start, end := nums[0], nums[1]
  407. var step T
  408. if start.Cmp(end) <= 0 {
  409. if len(nums) == 3 {
  410. step = nums[2]
  411. if step.Sign() <= 0 {
  412. return errs.BadValue{
  413. What: "step", Valid: "positive", Actual: vals.ToString(step)}
  414. }
  415. } else {
  416. step = d.one
  417. }
  418. var cur, next T
  419. for cur = start; cur.Cmp(end) < 0; cur = next {
  420. err := out.Put(vals.FromGo(cur))
  421. if err != nil {
  422. return err
  423. }
  424. next = d.newZero()
  425. next.Add(cur, step)
  426. cur = next
  427. }
  428. } else {
  429. if len(nums) == 3 {
  430. step = nums[2]
  431. if step.Sign() >= 0 {
  432. return errs.BadValue{
  433. What: "step", Valid: "negative", Actual: vals.ToString(step)}
  434. }
  435. } else {
  436. step = d.negOne
  437. }
  438. var cur, next T
  439. for cur = start; cur.Cmp(end) > 0; cur = next {
  440. err := out.Put(vals.FromGo(cur))
  441. if err != nil {
  442. return err
  443. }
  444. next = d.newZero()
  445. next.Add(cur, step)
  446. cur = next
  447. }
  448. }
  449. return nil
  450. }