Selaa lähdekoodia

[🚀] 添加option

bvbej 9 kuukautta sitten
vanhempi
säilyke
1301ea1e3d

+ 51 - 5
pkg/downloader/controller/controller.go

@@ -3,6 +3,7 @@ package controller
 import (
 	"golang.org/x/net/proxy"
 	"net"
+	"net/http"
 	"os"
 	"time"
 )
@@ -13,14 +14,53 @@ type Controller interface {
 	Write(name string, offset int64, buf []byte) (int, error)
 	Close(name string) error
 	ContextDialer() (proxy.Dialer, error)
+	ContextCookie() http.CookieJar
+	ContextTimeout() time.Duration
+}
+
+type Option func(*option)
+
+type option struct {
+	CookieJar http.CookieJar
+	Timeout   time.Duration
+	Dialer    proxy.Dialer
+}
+
+func WithCookie(cookieJar http.CookieJar) Option {
+	return func(opt *option) {
+		opt.CookieJar = cookieJar
+	}
+}
+
+func WithTimeout(timeout time.Duration) Option {
+	return func(opt *option) {
+		opt.Timeout = timeout
+	}
+}
+
+func WithDialer(dialer proxy.Dialer) Option {
+	return func(opt *option) {
+		opt.Dialer = dialer
+	}
 }
 
 type DefaultController struct {
+	*option
 	Files map[string]*os.File
 }
 
-func NewController() *DefaultController {
-	return &DefaultController{Files: make(map[string]*os.File)}
+func NewController(options ...Option) *DefaultController {
+	opt := new(option)
+	for _, f := range options {
+		f(opt)
+	}
+	if opt.Timeout == 0 {
+		opt.Timeout = time.Second * 30
+	}
+	return &DefaultController{
+		Files:  make(map[string]*os.File),
+		option: opt,
+	}
 }
 
 func (c *DefaultController) Touch(name string, size int64) (file *os.File, err error) {
@@ -56,9 +96,15 @@ func (c *DefaultController) Close(name string) error {
 }
 
 func (c *DefaultController) ContextDialer() (proxy.Dialer, error) {
-	// return proxy.SOCKS5("tpc", "127.0.0.1:9999", nil, nil)
-	var dialer proxy.Dialer
-	return &DialerWarp{dialer: dialer}, nil
+	return &DialerWarp{dialer: c.Dialer}, nil
+}
+
+func (c *DefaultController) ContextCookie() http.CookieJar {
+	return c.CookieJar
+}
+
+func (c *DefaultController) ContextTimeout() time.Duration {
+	return c.Timeout
 }
 
 type DialerWarp struct {

+ 4 - 4
pkg/downloader/downloader.go

@@ -46,9 +46,9 @@ type downloader struct {
 	finishedCh    chan error
 }
 
-func newDownloader(f func() (protocols []string, builder func() fetcher.Fetcher)) *downloader {
+func newDownloader(f func() (protocols []string, builder func() fetcher.Fetcher), options ...controller.Option) *downloader {
 	d := &downloader{
-		DefaultController: controller.NewController(),
+		DefaultController: controller.NewController(options...),
 		finishedCh:        make(chan error, 1),
 	}
 
@@ -259,8 +259,8 @@ func (b *boot) Create(opts *base.Options) <-chan error {
 }
 
 // New 一个文件对应一个实例
-func New() Boot {
+func New(options ...controller.Option) Boot {
 	return &boot{
-		downloader: newDownloader(http.FetcherBuilder),
+		downloader: newDownloader(http.FetcherBuilder, options...),
 	}
 }

+ 32 - 19
pkg/downloader/protocol/http/fetcher.go

@@ -8,10 +8,9 @@ import (
 	"git.bvbej.com/bvbej/base-golang/pkg/downloader/fetcher"
 	"golang.org/x/sync/errgroup"
 	"io"
-	"io/ioutil"
 	"mime"
+	"net"
 	"net/http"
-	"net/http/cookiejar"
 	"net/url"
 	"path"
 	"path/filepath"
@@ -63,11 +62,14 @@ func FetcherBuilder() ([]string, func() fetcher.Fetcher) {
 }
 
 func (f *Fetcher) Resolve(req *base.Request) (*base.Resource, error) {
-	httpReq, err := buildRequest(nil, req)
+	httpReq, err := f.buildRequest(nil, req)
+	if err != nil {
+		return nil, err
+	}
+	client, err := f.buildClient()
 	if err != nil {
 		return nil, err
 	}
-	client := buildClient()
 	// 只访问一个字节,测试资源是否支持Range请求
 	httpReq.Header.Set(base.HttpHeaderRange, fmt.Sprintf(base.HttpHeaderRangeFormat, 0, 0))
 	httpResp, err := client.Do(httpReq)
@@ -75,7 +77,7 @@ func (f *Fetcher) Resolve(req *base.Request) (*base.Resource, error) {
 		return nil, err
 	}
 	// 拿到响应头就关闭,不用加defer
-	httpResp.Body.Close()
+	_ = httpResp.Body.Close()
 	res := &base.Resource{
 		Req:   req,
 		Range: false,
@@ -222,7 +224,7 @@ func (f *Fetcher) fetch() {
 	f.ctx, f.cancel = context.WithCancel(context.Background())
 	eg, _ := errgroup.WithContext(f.ctx)
 	for i := 0; i < f.opts.Connections; i++ {
-		j := i //不加这一行会造成越界报错
+		j := i //TODO loop var per loop(1.22已解决,loop var per-iteration)
 		eg.Go(func() error {
 			return f.fetchChunk(j)
 		})
@@ -249,14 +251,17 @@ func (f *Fetcher) fetchChunk(index int) (err error) {
 	filename := f.filename()
 	chunk := f.chunks[index]
 
-	httpReq, err := buildRequest(f.ctx, f.res.Req)
+	httpReq, err := f.buildRequest(f.ctx, f.res.Req)
 	if err != nil {
 		return err
 	}
-	var (
-		client = buildClient()
-		buf    = make([]byte, 8192)
-	)
+
+	client, err := f.buildClient()
+	if err != nil {
+		return err
+	}
+
+	var buf = make([]byte, 8192)
 
 	// 重试10次
 	for i := 0; i < 10; i++ {
@@ -342,16 +347,24 @@ func (f *Fetcher) fetchChunk(index int) (err error) {
 	return
 }
 
-func buildClient() *http.Client {
-	// Cookie handle
-	jar, _ := cookiejar.New(nil)
-	return &http.Client{
-		Jar:     jar,
-		Timeout: time.Second * 10,
+func (f *Fetcher) buildClient() (*http.Client, error) {
+	dialer, err := f.Ctl.ContextDialer()
+	if err != nil {
+		return nil, err
+	}
+	transport := &http.Transport{
+		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+			return dialer.Dial(network, addr)
+		},
 	}
+	return &http.Client{
+		Jar:       f.Ctl.ContextCookie(),
+		Timeout:   f.Ctl.ContextTimeout(),
+		Transport: transport,
+	}, nil
 }
 
-func buildRequest(ctx context.Context, req *base.Request) (httpReq *http.Request, err error) {
+func (f *Fetcher) buildRequest(ctx context.Context, req *base.Request) (httpReq *http.Request, err error) {
 	reqUrl, err := url.Parse(req.URL)
 	if err != nil {
 		return
@@ -377,7 +390,7 @@ func buildRequest(ctx context.Context, req *base.Request) (httpReq *http.Request
 			}
 		}
 		if extra.Body != "" {
-			body = ioutil.NopCloser(bytes.NewBufferString(extra.Body))
+			body = io.NopCloser(bytes.NewBufferString(extra.Body))
 		}
 	}