package zioutil

// Inspired by transport implementation of buptczq/gotun.
// Some code was copied from buptczq/gotun.

import (
	"context"
	"io"
	"strings"
	"sync"
	"time"
)

// batchTransporterDefaultBufSize is the copy buffer size used when a
// BatchTransporterGlobal is built with a non-positive size, is a zero value,
// or is nil. It is the same size io.Copy allocates internally.
const batchTransporterDefaultBufSize = 32 * 1024

// BatchTransporterGlobal holds state shared by many BatchTransporters: a pool
// of copy buffers. The zero value is usable and hands out buffers of
// batchTransporterDefaultBufSize bytes.
type BatchTransporterGlobal struct {
	// bufferPool holds *[]byte rather than []byte so that Put does not
	// allocate when boxing the slice header into an interface (SA6002).
	bufferPool sync.Pool
}

// NewBatchTransporterGlobal returns a BatchTransporterGlobal whose pooled
// buffers are bufsize bytes long. A non-positive bufsize falls back to
// batchTransporterDefaultBufSize. A zero-length buffer could make the copy
// loop spin forever, because readers may return (0, nil) for an empty
// buffer. A negative size would panic in make.
func NewBatchTransporterGlobal(bufsize int) *BatchTransporterGlobal {
	if bufsize <= 0 {
		bufsize = batchTransporterDefaultBufSize
	}
	return &BatchTransporterGlobal{
		bufferPool: sync.Pool{New: func() interface{} {
			buf := make([]byte, bufsize)
			return &buf
		}},
	}
}

// getBuffer takes a copy buffer from the pool. It allocates a default-sized
// buffer when g is nil or its pool has no New function (zero value).
func (g *BatchTransporterGlobal) getBuffer() *[]byte {
	if g != nil {
		if buf, ok := g.bufferPool.Get().(*[]byte); ok && buf != nil {
			return buf
		}
	}
	buf := make([]byte, batchTransporterDefaultBufSize)
	return &buf
}

// putBuffer returns buf to the pool; it is a no-op when g is nil.
func (g *BatchTransporterGlobal) putBuffer(buf *[]byte) {
	if g != nil {
		g.bufferPool.Put(buf)
	}
}

// BatchTransporter relays data in both directions between rwc1 and rwc2 and
// records per-direction statistics. The suffix 12 means rwc1 -> rwc2 and 21
// means rwc2 -> rwc1. The exported result fields are written by goroutines
// started in Copy and must only be read after Copy has returned.
type BatchTransporter struct {
	btg        *BatchTransporterGlobal
	rwc1, rwc2 io.ReadWriteCloser

	// ReadTime12/ReadTime21 start as the construction time and are overwritten
	// when the first Read in that direction returns.
	ReadTime12, ReadTime21 time.Time
	// Error12/Error21 hold the error that ended each direction. io.EOF from
	// the source, or stopping because the other direction ended, is
	// reported as nil.
	Error12, Error21 error
	// Bytes12/Bytes21 count bytes successfully written to the destination.
	Bytes12, Bytes21 int64
	// Start is supplied by the caller; Stop is set when Copy returns.
	Start, Stop time.Time
}

// NewBatchTrasporter returns a BatchTransporter that relays between rwc1 and
// rwc2 using copy buffers from global. global may be nil, in which case
// buffers are allocated per Copy. start is stored unchanged in Start.
// The misspelling of "Transporter" in this name is kept for compatibility.
func NewBatchTrasporter(global *BatchTransporterGlobal, start time.Time, rwc1, rwc2 io.ReadWriteCloser) *BatchTransporter {
	now := time.Now()
	return &BatchTransporter{
		btg:        global,
		rwc1:       rwc1,
		rwc2:       rwc2,
		ReadTime12: now,
		ReadTime21: now,
		Start:      start,
	}
}

// copyBuffer copies from src to dst through buf. It stops when one of these
// happens:
//   - src returns an error (io.EOF counts as success);
//   - a write fails or is short;
//   - ctx is cancelled.
//
// ctx is only checked between reads; a Read that is already blocked has to be
// interrupted by closing src. The time at which the first Read returns is
// stored in ReadTime12 when is12 is true, otherwise in ReadTime21.
// It returns the number of bytes written to dst.
func (bt *BatchTransporter) copyBuffer(ctx context.Context, dst io.Writer, src io.Reader, buf []byte, is12 bool) (cnt int64, err error) {
	firstRead := true
	for {
		select {
		case <-ctx.Done():
			// The opposite direction has finished; stop without an error.
			return cnt, nil
		default:
		}

		nr, errr := src.Read(buf)
		if firstRead {
			firstRead = false
			if is12 {
				bt.ReadTime12 = time.Now()
			} else {
				bt.ReadTime21 = time.Now()
			}
		}
		// Per the io.Reader contract, bytes read must be handled before errr.
		if nr > 0 {
			nw, errw := dst.Write(buf[:nr])
			if nw > 0 {
				cnt += int64(nw)
			}
			if errw != nil {
				return cnt, errw
			}
			if nw != nr {
				return cnt, io.ErrShortWrite
			}
		}
		if errr != nil {
			if errr == io.EOF {
				return cnt, nil
			}
			return cnt, errr
		}
	}
}

// Copy relays data in both directions and blocks until both have finished.
// When one direction stops, it cancels the shared context and closes its
// destination. For net.Conn, Close unblocks a pending Read, so the opposite
// direction stops too. Both rwc1 and rwc2 are therefore closed when Copy
// returns. Results are stored in Bytes12/21, Error12/21, ReadTime12/21 and
// Stop.
func (bt *BatchTransporter) Copy() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(2)
	relay := func(dst io.WriteCloser, src io.Reader, is12 bool, n *int64, e *error) {
		defer wg.Done()
		buf := bt.btg.getBuffer()
		defer bt.btg.putBuffer(buf)
		*n, *e = bt.copyBuffer(ctx, dst, src, *buf, is12)
		cancel()
		// Best-effort: the error that ended this direction is already recorded.
		_ = dst.Close()
	}
	go relay(bt.rwc2, bt.rwc1, true, &bt.Bytes12, &bt.Error12)
	go relay(bt.rwc1, bt.rwc2, false, &bt.Bytes21, &bt.Error21)
	wg.Wait()
	bt.Stop = time.Now()
}

// filterNetworkConnClosedError returns nil when err is nil or its message
// contains "use of closed network connection"; otherwise it returns err.
func filterNetworkConnClosedError(err error) error {
	if err == nil {
		return nil
	}
	if strings.Contains(err.Error(), "use of closed network connection") {
		return nil
	}
	return err
}

// NoNetworkConnClosedError returns the first non-nil of Error12 and Error21
// after dropping "use of closed network connection" errors, or nil.
func (bt *BatchTransporter) NoNetworkConnClosedError() error {
	err12 := filterNetworkConnClosedError(bt.Error12)
	err21 := filterNetworkConnClosedError(bt.Error21)
	if err12 == nil {
		err12 = err21
	}
	return err12
}

// PickError returns Error12 if it is non-nil, otherwise Error21.
func (bt *BatchTransporter) PickError() error {
	err12 := bt.Error12
	err21 := bt.Error21
	if err12 == nil {
		err12 = err21
	}
	return err12
}

// GetError12 returns the error that ended the rwc1 -> rwc2 direction.
func (bt *BatchTransporter) GetError12() error {
	return bt.Error12
}

// GetError21 returns the error that ended the rwc2 -> rwc1 direction.
func (bt *BatchTransporter) GetError21() error {
	return bt.Error21
}