batch_transporter.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package zioutil
  2. // Inspired by transport implementation of buptczq/gotun.
  3. // Some code was copied from buptczq/gotun.
  4. import (
  5. "context"
  6. "io"
  7. "strings"
  8. "sync"
  9. "time"
  10. )
  // BatchTransporterGlobal holds state shared by many BatchTransporter values:
  // a pool of copy buffers, so each Copy call borrows buffers instead of
  // allocating fresh ones.
  type BatchTransporterGlobal struct {
  	bufferPool sync.Pool
  }

  // NewBatchTransporterGlobal returns a BatchTransporterGlobal whose pool hands
  // out byte slices of length bufsize.
  //
  // A non-positive bufsize falls back to 32 KiB (the buffer size io.Copy uses).
  // A zero-length buffer would make the copy loop call Read with an empty slice,
  // which readers normally answer with (0, nil), so the loop would spin without
  // progress. A negative size would panic in make() the first time the pool
  // allocates.
  func NewBatchTransporterGlobal(bufsize int) *BatchTransporterGlobal {
  	if bufsize <= 0 {
  		bufsize = 32 * 1024
  	}
  	return &BatchTransporterGlobal{
  		bufferPool: sync.Pool{New: func() interface{} {
  			return make([]byte, bufsize)
  		}},
  	}
  }
  21. type BatchTransporter struct {
  22. btg *BatchTransporterGlobal
  23. rwc1, rwc2 io.ReadWriteCloser
  24. ReadTime12, ReadTime21 time.Time
  25. Error12, Error21 error
  26. Bytes12, Bytes21 int64
  27. Start, Stop time.Time
  28. }
  29. func NewBatchTrasporter(global *BatchTransporterGlobal, start time.Time, rwc1, rwc2 io.ReadWriteCloser) *BatchTransporter {
  30. return &BatchTransporter{
  31. btg: global,
  32. rwc1: rwc1,
  33. rwc2: rwc2,
  34. ReadTime12: time.Now(),
  35. ReadTime21: time.Now(),
  36. Start: start,
  37. }
  38. }
  39. func (bt *BatchTransporter) copyBuffer(ctx context.Context, dst io.Writer, src io.Reader, buf []byte, is12 bool) (cnt int64, err error) {
  40. nftc := false
  41. for {
  42. select {
  43. case <-ctx.Done():
  44. return
  45. default:
  46. }
  47. nr, errr := src.Read(buf)
  48. if !nftc {
  49. nftc = true
  50. if is12 {
  51. bt.ReadTime12 = time.Now()
  52. } else {
  53. bt.ReadTime21 = time.Now()
  54. }
  55. }
  56. if nr > 0 {
  57. nw, errw := dst.Write(buf[0:nr])
  58. if nw > 0 {
  59. cnt += int64(nw)
  60. }
  61. if errw != nil {
  62. err = errw
  63. break
  64. }
  65. if nr != nw {
  66. err = io.ErrShortWrite
  67. break
  68. }
  69. }
  70. if errr != nil {
  71. if errr != io.EOF {
  72. err = errr
  73. }
  74. break
  75. }
  76. }
  77. return cnt, err
  78. }
  79. func (bt *BatchTransporter) Copy() {
  80. ctx, cancel := context.WithCancel(context.Background())
  81. wg := &sync.WaitGroup{}
  82. wg.Add(2)
  83. go func() {
  84. defer wg.Done()
  85. buf := bt.btg.bufferPool.Get().([]byte)
  86. defer bt.btg.bufferPool.Put(buf)
  87. bt.Bytes12, bt.Error12 = bt.copyBuffer(ctx, bt.rwc2, bt.rwc1, buf, true)
  88. cancel()
  89. bt.rwc2.Close()
  90. }()
  91. go func() {
  92. defer wg.Done()
  93. buf := bt.btg.bufferPool.Get().([]byte)
  94. defer bt.btg.bufferPool.Put(buf)
  95. bt.Bytes21, bt.Error21 = bt.copyBuffer(ctx, bt.rwc1, bt.rwc2, buf, false)
  96. cancel()
  97. bt.rwc1.Close()
  98. }()
  99. wg.Wait()
  100. bt.Stop = time.Now()
  101. }
  // filterNetworkConnClosedError returns nil for errors whose message contains
  // "use of closed network connection". Every other error, including nil, is
  // returned unchanged.
  //
  // This is typically what a direction sees after the other direction has
  // closed the connection. The test is a substring match on err.Error(), so it
  // also matches wrapped errors that carry that text.
  func filterNetworkConnClosedError(err error) error {
  	const closedConnMsg = "use of closed network connection"
  	if err != nil && strings.Contains(err.Error(), closedConnMsg) {
  		return nil
  	}
  	return err
  }
  111. func (bt *BatchTransporter) NoNetworkConnClosedError() error {
  112. err12 := filterNetworkConnClosedError(bt.Error12)
  113. err21 := filterNetworkConnClosedError(bt.Error21)
  114. if err12 == nil {
  115. err12 = err21
  116. }
  117. return err12
  118. }
  119. func (bt *BatchTransporter) PickError() error {
  120. err12 := bt.Error12
  121. err21 := bt.Error21
  122. if err12 == nil {
  123. err12 = err21
  124. }
  125. return err12
  126. }
  127. func (bt *BatchTransporter) GetError12() error {
  128. return bt.Error12
  129. }
  130. func (bt *BatchTransporter) GetError21() error {
  131. return bt.Error21
  132. }