diff --git a/libwallet/assets_manager.go b/libwallet/assets_manager.go index 441d06c67..bb44c7eaa 100644 --- a/libwallet/assets_manager.go +++ b/libwallet/assets_manager.go @@ -965,7 +965,7 @@ func (mgr *AssetsManager) DeleteDEXData() error { } func (mgr *AssetsManager) WatchBalanceChange(listen func()) { - // Reload wallets unmixed balance and reload UI on new blocks. + // Reload total balance on new tx. txAndBlockNotificationListener := &sharedW.TxAndBlockNotificationListener{ OnTransactionConfirmed: func(_ int, _ string, _ int32) { listen() @@ -974,6 +974,8 @@ func (mgr *AssetsManager) WatchBalanceChange(listen func()) { listen() }, } + + // add tx listener for _, wallet := range mgr.AllWallets() { if !wallet.IsNotificationListenerExist(assetIdentifier) { if err := wallet.AddTxAndBlockNotificationListener(txAndBlockNotificationListener, assetIdentifier); err != nil { @@ -981,10 +983,26 @@ func (mgr *AssetsManager) WatchBalanceChange(listen func()) { } } } + + // add rate listener + rateListener := &ext.RateListener{ + OnRateUpdated: func() { + listen() + }, + } + if !mgr.RateSource.IsRateListenerExist(assetIdentifier) { + if err := mgr.RateSource.AddRateListener(rateListener, assetIdentifier); err != nil { + log.Error("Can't listen rate notification ") + } + } } func (mgr *AssetsManager) RemoveAssetChange() { + // Remove all listener on tx notification for _, wallet := range mgr.AllWallets() { wallet.RemoveTxAndBlockNotificationListener(assetIdentifier) } + + // Remove listener on rate notification + mgr.RateSource.RemoveRateListener(assetIdentifier) } diff --git a/libwallet/ext/rate_source.go b/libwallet/ext/rate_source.go index 7df76deec..1c038d106 100644 --- a/libwallet/ext/rate_source.go +++ b/libwallet/ext/rate_source.go @@ -6,6 +6,7 @@ package ext import ( "context" + "errors" "fmt" "strconv" "strings" @@ -115,6 +116,9 @@ type RateSource interface { GetTicker(market values.Market, cacheOnly bool) *Ticker ToggleStatus(disable bool) ToggleSource(newSource string) error + AddRateListener(listener *RateListener, uniqueIdentifier string) error + RemoveRateListener(uniqueIdentifier string) + IsRateListenerExist(uniqueIdentifier string) bool } // RateListener listens for new tickers and rate source change notifications. @@ -141,6 +145,9 @@ type CommonRateSource struct { lastUpdate time.Time disableConversionExchange func() + + notificationListenersMu sync.RWMutex + ratesListeners map[string]*RateListener } // Used to initialize a rate source. @@ -155,6 +162,7 @@ func NewCommonRateSource(ctx context.Context, source string, disableConversionEx tickers: make(map[values.Market]*Ticker), sourceChanged: make(chan *struct{}), disableConversionExchange: disableConversionExchange, + ratesListeners: make(map[string]*RateListener), } s.getTicker = s.sourceGetTickerFunc(source) s.cond = sync.NewCond(&s.mtx) @@ -205,6 +213,36 @@ func (cs *CommonRateSource) isDisabled() bool { return cs.disabled } +func (cs *CommonRateSource) AddRateListener(listener *RateListener, uniqueIdentifier string) error { + if _, ok := cs.ratesListeners[uniqueIdentifier]; ok { + return errors.New(utils.ErrListenerAlreadyExist) + } + + cs.notificationListenersMu.Lock() + defer cs.notificationListenersMu.Unlock() + cs.ratesListeners[uniqueIdentifier] = listener + return nil +} + +func (cs *CommonRateSource) IsRateListenerExist(uniqueIdentifier string) bool { + _, ok := cs.ratesListeners[uniqueIdentifier] + return ok +} + +func (cs *CommonRateSource) RemoveRateListener(uniqueIdentifier string) { + cs.notificationListenersMu.Lock() + defer cs.notificationListenersMu.Unlock() + delete(cs.ratesListeners, uniqueIdentifier) +} + +func (cs *CommonRateSource) pushlishRateUpdated() { + for _, listener := range cs.ratesListeners { + if listener.OnRateUpdated != nil { + listener.OnRateUpdated() + } + } +} + // ToggleSource changes the rate source to newSource. This method takes some // time to refresh the rates and should be executed a a goroutine. func (cs *CommonRateSource) ToggleSource(newSource string) error { @@ -275,6 +313,7 @@ func (cs *CommonRateSource) Refresh(force bool) { }() defer cs.ratesUpdated(time.Now()) + defer cs.pushlishRateUpdated() tickers := make(map[values.Market]*Ticker) if !force { diff --git a/ui/page/root/home_page.go b/ui/page/root/home_page.go index 25cd7f470..34769b86e 100644 --- a/ui/page/root/home_page.go +++ b/ui/page/root/home_page.go @@ -199,6 +199,7 @@ func (hp *HomePage) OnNavigatedTo() { } hp.AssetsManager.WatchBalanceChange(func() { + fmt.Println("Update Total balance") go hp.CalculateAssetsUSDBalance() }) }