diff --git a/main.go b/main.go index 411c7fa..0040cc2 100644 --- a/main.go +++ b/main.go @@ -36,4 +36,5 @@ func main() { for _, task := range tasks { c.Download(task) } + c.WorkerPool.Wait() } diff --git a/utils/worker.go b/utils/worker.go index 7d7116c..9799727 100644 --- a/utils/worker.go +++ b/utils/worker.go @@ -1,6 +1,9 @@ package utils -import "fmt" +import ( + "fmt" + "sync" +) type TaskQueue chan *MultiThreadDownloader @@ -9,6 +12,8 @@ type Worker struct { } type WorkerPool struct { + sync.WaitGroup + WorkerCount int TaskQueue TaskQueue WorkerQueue chan TaskQueue @@ -26,7 +31,7 @@ func NewWorkerPool(WorkerCount int) *WorkerPool { } } -func (w *Worker) Run(wq chan TaskQueue) { +func (w *Worker) Run(wq chan TaskQueue, owner *WorkerPool) { go func() { for { wq <- w.TaskChan @@ -38,6 +43,7 @@ func (w *Worker) Run(wq chan TaskQueue) { return } fmt.Println("下载完成", t.FullPath) + owner.Done() } } }() @@ -46,12 +52,13 @@ func (w *Worker) Run(wq chan TaskQueue) { func (wp *WorkerPool) Start() { for i := 0; i < wp.WorkerCount; i++ { w := NewWorker() - w.Run(wp.WorkerQueue) + w.Run(wp.WorkerQueue, wp) } go func() { for { select { case t := <-wp.TaskQueue: + wp.Add(1) w := <-wp.WorkerQueue w <- t }