123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- package zioutil
- // Inspired by transport implementation of buptczq/gotun.
- // Some code was copied from buptczq/gotun.
import (
	"context"
	"errors"
	"io"
	"strings"
	"sync"
	"time"
)
// BatchTransporterGlobal holds state that can be shared by many
// BatchTransporter values: a sync.Pool of fixed-size copy buffers.
type BatchTransporterGlobal struct {
	bufferPool sync.Pool
}

// defaultBatchBufferSize replaces a non-positive size passed to
// NewBatchTransporterGlobal. It is the same size io.Copy allocates.
const defaultBatchBufferSize = 32 * 1024

// NewBatchTransporterGlobal returns a BatchTransporterGlobal whose pool yields
// []byte buffers of length bufsize.
//
// A non-positive bufsize is replaced by defaultBatchBufferSize: a negative
// length would make the pool's make call panic on first use, and a
// zero-length buffer lets the copy loop spin, since readers commonly return
// (0, nil) when handed an empty buffer.
func NewBatchTransporterGlobal(bufsize int) *BatchTransporterGlobal {
	if bufsize <= 0 {
		bufsize = defaultBatchBufferSize
	}
	return &BatchTransporterGlobal{
		bufferPool: sync.Pool{New: func() interface{} {
			return make([]byte, bufsize)
		}},
	}
}

// BatchTransporter relays data in both directions between rwc1 and rwc2.
// Suffix "12" refers to data read from rwc1 and written to rwc2; suffix "21"
// to the opposite direction.
type BatchTransporter struct {
	btg        *BatchTransporterGlobal
	rwc1, rwc2 io.ReadWriteCloser

	// ReadTime12 and ReadTime21 start at construction time and are
	// overwritten when the first Read in that direction returns.
	ReadTime12, ReadTime21 time.Time

	// Error12 and Error21 hold the error that ended each direction;
	// io.EOF from the source is not recorded.
	Error12, Error21 error

	// Bytes12 and Bytes21 count the bytes written in each direction.
	Bytes12, Bytes21 int64

	// Start is supplied by the caller; Stop is set when Copy returns.
	Start, Stop time.Time
}

// NewBatchTransporter returns a BatchTransporter that relays between rwc1 and
// rwc2 using buffers from global. start is stored unchanged in Start, and both
// ReadTime fields are initialised to the same current instant.
func NewBatchTransporter(global *BatchTransporterGlobal, start time.Time, rwc1, rwc2 io.ReadWriteCloser) *BatchTransporter {
	now := time.Now()
	return &BatchTransporter{
		btg:        global,
		rwc1:       rwc1,
		rwc2:       rwc2,
		ReadTime12: now,
		ReadTime21: now,
		Start:      start,
	}
}

// NewBatchTrasporter is the original, misspelled name of NewBatchTransporter,
// kept so existing callers continue to compile.
//
// Deprecated: use NewBatchTransporter.
func NewBatchTrasporter(global *BatchTransporterGlobal, start time.Time, rwc1, rwc2 io.ReadWriteCloser) *BatchTransporter {
	return NewBatchTransporter(global, start, rwc1, rwc2)
}
- func (bt *BatchTransporter) copyBuffer(ctx context.Context, dst io.Writer, src io.Reader, buf []byte, is12 bool) (cnt int64, err error) {
- nftc := false
- for {
- select {
- case <-ctx.Done():
- return
- default:
- }
- nr, errr := src.Read(buf)
- if !nftc {
- nftc = true
- if is12 {
- bt.ReadTime12 = time.Now()
- } else {
- bt.ReadTime21 = time.Now()
- }
- }
- if nr > 0 {
- nw, errw := dst.Write(buf[0:nr])
- if nw > 0 {
- cnt += int64(nw)
- }
- if errw != nil {
- err = errw
- break
- }
- if nr != nw {
- err = io.ErrShortWrite
- break
- }
- }
- if errr != nil {
- if errr != io.EOF {
- err = errr
- }
- break
- }
- }
- return cnt, err
- }
- func (bt *BatchTransporter) Copy() {
- ctx, cancel := context.WithCancel(context.Background())
- wg := &sync.WaitGroup{}
- wg.Add(2)
- go func() {
- defer wg.Done()
- buf := bt.btg.bufferPool.Get().([]byte)
- defer bt.btg.bufferPool.Put(buf)
- bt.Bytes12, bt.Error12 = bt.copyBuffer(ctx, bt.rwc2, bt.rwc1, buf, true)
- cancel()
- bt.rwc2.Close()
- }()
- go func() {
- defer wg.Done()
- buf := bt.btg.bufferPool.Get().([]byte)
- defer bt.btg.bufferPool.Put(buf)
- bt.Bytes21, bt.Error21 = bt.copyBuffer(ctx, bt.rwc1, bt.rwc2, buf, false)
- cancel()
- bt.rwc1.Close()
- }()
- wg.Wait()
- bt.Stop = time.Now()
- }
// filterNetworkConnClosedError returns nil when err's message contains
// "use of closed network connection", and err unchanged otherwise (including
// when err is nil). Such errors are produced when a connection is closed
// while another goroutine is still reading from it, which is an expected way
// for a relay to end.
func filterNetworkConnClosedError(err error) error {
	if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
		return nil
	}
	return err
}
- func (bt *BatchTransporter) NoNetworkConnClosedError() error {
- err12 := filterNetworkConnClosedError(bt.Error12)
- err21 := filterNetworkConnClosedError(bt.Error21)
- if err12 == nil {
- err12 = err21
- }
- return err12
- }
- func (bt *BatchTransporter) PickError() error {
- err12 := bt.Error12
- err21 := bt.Error21
- if err12 == nil {
- err12 = err21
- }
- return err12
- }
- func (bt *BatchTransporter) GetError12() error {
- return bt.Error12
- }
- func (bt *BatchTransporter) GetError21() error {
- return bt.Error21
- }
|