|
@@ -0,0 +1,145 @@
|
|
|
+package zioutil
|
|
|
+
|
|
|
+// Inspired by transport implementation of buptczq/gotun.
|
|
|
+// Some code was copied from buptczq/gotun.
|
|
|
+
|
|
|
+import (
|
|
|
+ "context"
|
|
|
+ "io"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+type BatchTransporterGlobal struct {
|
|
|
+ bufferPool sync.Pool
|
|
|
+}
|
|
|
+
|
|
|
+func NewBatchTransporterGlobal(bufsize int) *BatchTransporterGlobal {
|
|
|
+ return &BatchTransporterGlobal{
|
|
|
+ bufferPool: sync.Pool{New: func() interface{} {
|
|
|
+ return make([]byte, bufsize)
|
|
|
+ }},
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type BatchTransporter struct {
|
|
|
+ btg *BatchTransporterGlobal
|
|
|
+ rwc1, rwc2 io.ReadWriteCloser
|
|
|
+ ReadTime12, ReadTime21 time.Time
|
|
|
+ Error12, Error21 error
|
|
|
+ Bytes12, Bytes21 int64
|
|
|
+ Start, Stop time.Time
|
|
|
+}
|
|
|
+
|
|
|
+func NewBatchTrasporter(global *BatchTransporterGlobal, start time.Time, rwc1, rwc2 io.ReadWriteCloser) *BatchTransporter {
|
|
|
+ return &BatchTransporter{
|
|
|
+ btg: global,
|
|
|
+ rwc1: rwc1,
|
|
|
+ rwc2: rwc2,
|
|
|
+ ReadTime12: time.Now(),
|
|
|
+ ReadTime21: time.Now(),
|
|
|
+ Start: start,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (bt *BatchTransporter) copyBuffer(ctx context.Context, dst io.Writer, src io.Reader, buf []byte, is12 bool) (cnt int64, err error) {
|
|
|
+ nftc := false
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ nr, errr := src.Read(buf)
|
|
|
+ if !nftc {
|
|
|
+ nftc = true
|
|
|
+ if is12 {
|
|
|
+ bt.ReadTime12 = time.Now()
|
|
|
+ } else {
|
|
|
+ bt.ReadTime21 = time.Now()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if nr > 0 {
|
|
|
+ nw, errw := dst.Write(buf[0:nr])
|
|
|
+ if nw > 0 {
|
|
|
+ cnt += int64(nw)
|
|
|
+ }
|
|
|
+ if errw != nil {
|
|
|
+ err = errw
|
|
|
+ break
|
|
|
+ }
|
|
|
+ if nr != nw {
|
|
|
+ err = io.ErrShortWrite
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if errr != nil {
|
|
|
+ if errr != io.EOF {
|
|
|
+ err = errr
|
|
|
+ }
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return cnt, err
|
|
|
+}
|
|
|
+
|
|
|
+func (bt *BatchTransporter) Copy() {
|
|
|
+ ctx, cancel := context.WithCancel(context.Background())
|
|
|
+ wg := &sync.WaitGroup{}
|
|
|
+ wg.Add(2)
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ buf := bt.btg.bufferPool.Get().([]byte)
|
|
|
+ defer bt.btg.bufferPool.Put(buf)
|
|
|
+ bt.Bytes12, bt.Error12 = bt.copyBuffer(ctx, bt.rwc2, bt.rwc1, buf, true)
|
|
|
+ cancel()
|
|
|
+ bt.rwc2.Close()
|
|
|
+ }()
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ buf := bt.btg.bufferPool.Get().([]byte)
|
|
|
+ defer bt.btg.bufferPool.Put(buf)
|
|
|
+ bt.Bytes21, bt.Error21 = bt.copyBuffer(ctx, bt.rwc1, bt.rwc2, buf, false)
|
|
|
+ cancel()
|
|
|
+ bt.rwc1.Close()
|
|
|
+ }()
|
|
|
+ wg.Wait()
|
|
|
+ bt.Stop = time.Now()
|
|
|
+}
|
|
|
+
|
|
|
+func filterNetworkConnClosedError(err error) error {
|
|
|
+ if err == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ if strings.Contains(err.Error(), "use of closed network connection") {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func (bt *BatchTransporter) NoNetworkConnClosedError() error {
|
|
|
+ err12 := filterNetworkConnClosedError(bt.Error12)
|
|
|
+ err21 := filterNetworkConnClosedError(bt.Error21)
|
|
|
+ if err12 == nil {
|
|
|
+ err12 = err21
|
|
|
+ }
|
|
|
+ return err12
|
|
|
+}
|
|
|
+
|
|
|
+func (bt *BatchTransporter) PickError() error {
|
|
|
+ err12 := bt.Error12
|
|
|
+ err21 := bt.Error21
|
|
|
+ if err12 == nil {
|
|
|
+ err12 = err21
|
|
|
+ }
|
|
|
+ return err12
|
|
|
+}
|
|
|
+
|
|
|
+func (bt *BatchTransporter) GetError12() error {
|
|
|
+ return bt.Error12
|
|
|
+}
|
|
|
+
|
|
|
+func (bt *BatchTransporter) GetError21() error {
|
|
|
+ return bt.Error21
|
|
|
+}
|