diff --git a/app/internal/utils/geoloader.go b/app/internal/utils/geoloader.go index 56cb205e4f..051773124c 100644 --- a/app/internal/utils/geoloader.go +++ b/app/internal/utils/geoloader.go @@ -1,6 +1,7 @@ package utils import ( + "fmt" "io" "net/http" "os" @@ -15,6 +16,7 @@ const ( geoipURL = "https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geoip.dat" geositeFilename = "geosite.dat" geositeURL = "https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geosite.dat" + geoDlTmpPattern = ".hysteria-geoloader.dlpart.*" geoDefaultUpdateInterval = 7 * 24 * time.Hour // 7 days ) @@ -49,7 +51,7 @@ func (l *GeoLoader) shouldDownload(filename string) bool { } } -func (l *GeoLoader) download(filename, url string) error { +func (l *GeoLoader) downloadAndCheck(filename, url string, checkFunc func(filename string) error) error { l.DownloadFunc(filename, url) resp, err := http.Get(url) @@ -59,16 +61,34 @@ func (l *GeoLoader) download(filename, url string) error { } defer resp.Body.Close() - f, err := os.Create(filename) + f, err := os.CreateTemp(".", geoDlTmpPattern) if err != nil { l.DownloadErrFunc(err) return err } - defer f.Close() + defer os.Remove(f.Name()) _, err = io.Copy(f, resp.Body) - l.DownloadErrFunc(err) - return err + if err != nil { + f.Close() + l.DownloadErrFunc(err) + return err + } + f.Close() + + err = checkFunc(f.Name()) + if err != nil { + l.DownloadErrFunc(fmt.Errorf("integrity check failed: %w", err)) + return err + } + + err = os.Rename(f.Name(), filename) + if err != nil { + l.DownloadErrFunc(fmt.Errorf("rename failed: %w", err)) + return err + } + + return nil } func (l *GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { @@ -82,7 +102,10 @@ func (l *GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { filename = geoipFilename } if autoDL && l.shouldDownload(filename) { - err := l.download(filename, geoipURL) + err := l.downloadAndCheck(filename, geoipURL, func(filename string) error { + _, err := v2geo.LoadGeoIP(filename) + return err + }) if err != nil { return nil, err } @@ -106,7 +129,10 @@ func (l *GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) { filename = geositeFilename } if autoDL && l.shouldDownload(filename) { - err := l.download(filename, geositeURL) + err := l.downloadAndCheck(filename, geositeURL, func(filename string) error { + _, err := v2geo.LoadGeoSite(filename) + return err + }) if err != nil { return nil, err }