-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpcnn.py
75 lines (61 loc) · 3.05 KB
/
pcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import networkx as nx
import numpy as np
# ----------------------------------------------------------------------------------------------------
# Update-rule switch read by NodeProbabilities.update_probabilities:
#   True  -> additive:       total += 1, and success += 1 when the transaction succeeded
#   False -> multiplicative: total *= 1.1, and success *= 1.21 when the transaction succeeded
# NOTE: the identifier is misspelled ("incerment") but other code looks it up by this exact name,
# so it must not be renamed on its own.
# ----------------------------------------------------------------------------------------------------
additive_incerment: bool = True
class NodeProbabilities:
    """Per-node estimate of transaction success probability toward each destination.

    Every destination keeps a (success, total) pseudo-count pair. The pair is
    seeded with a 10/20 prior on its first update, so both an unseen
    destination and a freshly seeded one report a probability of 0.5.
    """

    # Prior pseudo-counts installed the first time a destination is updated.
    _PRIOR_TOTAL = 20
    _PRIOR_SUCCESS = 10
    # Multiplicative-mode growth factors: a success scales success/total by
    # 1.21 / 1.1 = 1.1, a failure scales it by 1 / 1.1.
    _TOTAL_GROWTH = 1.1
    _SUCCESS_GROWTH = 1.21

    def __init__(self, node_id: str):
        self.node_id = node_id
        # destination -> (possibly fractional) count of successful transactions
        self.success_count = {}
        # destination -> (possibly fractional) count of all transactions
        self.total_count = {}

    def get_probability(self, destination: str) -> float:
        """Return the estimated success probability toward *destination*.

        Returns 0.5 when no transaction toward *destination* has been recorded.
        """
        total = self.total_count.get(destination, 0)
        if total == 0:
            return 0.5  # 50% success probability if no transactions have occurred
        return self.success_count.get(destination, 0) / total

    def update_probabilities(self, destination: str, success: bool):
        """Record one transaction outcome toward *destination*.

        The rule is selected by the module-level ``additive_incerment`` toggle:
          * additive:       total += 1, and success += 1 on success;
          * multiplicative: total *= 1.1, and success *= 1.21 on success.

        In multiplicative mode the success count is capped at the total count,
        so get_probability never returns more than 1.0.
        """
        if destination not in self.total_count:
            self.total_count[destination] = self._PRIOR_TOTAL
            self.success_count[destination] = self._PRIOR_SUCCESS
        if additive_incerment:
            self.total_count[destination] += 1
            if success:
                self.success_count[destination] += 1
        else:
            self.total_count[destination] *= self._TOTAL_GROWTH
            if success:
                # Uncapped, repeated successes push success/total past 1
                # (0.5 * 1.1**8 > 1), making (1 - p) negative for bidding.
                self.success_count[destination] = min(
                    self.success_count[destination] * self._SUCCESS_GROWTH,
                    self.total_count[destination],
                )
class PCNN:
    """Payment-channel network: a directed channel graph plus per-node success estimates."""

    def __init__(self):
        # Directed graph; each channel contributes one edge in each direction.
        self.G = nx.DiGraph()
        # channel id (the two endpoint ids sorted and joined with "-") -> channel record
        self.payment_channels = {}
        # node id -> NodeProbabilities
        self.node_probabilities = {}

    def add_payment_channel(self, source: str, destination: str, deposit: int, capacity: int, policy1: dict, policy2: dict):
        """Open a channel between *source* and *destination*.

        Args:
            source: one endpoint node id.
            destination: the other endpoint node id. The channel id sorts the
                two ids, so the endpoint order does not affect it.
            deposit: amount recorded for each endpoint in the channel record.
            capacity: capacity stored on both directed edges and in the channel
                record under the "capacity" key.
            policy1: policy attached to the source -> destination edge.
            policy2: policy attached to the destination -> source edge.

        If a channel with the same id already exists, a message is printed and
        nothing is changed.
        """
        channel_id = "-".join(sorted([source, destination]))
        print("Adding channel", channel_id)
        if channel_id in self.payment_channels:
            print("Channel already exists")
            return
        self.payment_channels[channel_id] = {
            source: deposit,
            destination: deposit,
            # String key, not the capacity value itself, so the record is addressable.
            "capacity": capacity,
        }
        self.G.add_edge(destination, source, capacity=capacity, policy=policy2)
        self.G.add_edge(source, destination, capacity=capacity, policy=policy1)
        for node in (source, destination):
            if node not in self.node_probabilities:
                self.node_probabilities[node] = NodeProbabilities(node)

    def get_probability(self, source: str, destination: str) -> float:
        """Return *source*'s estimated success probability toward *destination*.

        Raises:
            ValueError: if *source* has no probability record.
        """
        if source not in self.node_probabilities:
            raise ValueError(f"No record for node {source}")
        return self.node_probabilities[source].get_probability(destination)

    def update_transaction_success(self, source: str, destination: str, success: bool):
        """Record a transaction outcome from *source* toward *destination*.

        An unknown *source* is reported with a printed message and otherwise ignored.
        """
        if source not in self.node_probabilities:
            print(f"No probability record for node {source}")
            return
        self.node_probabilities[source].update_probabilities(destination, success)

    def get_bid_for_node(self, node: str, destination: str) -> float:
        """Return (1 - p) ** (K - 1) for *node*.

        p is *node*'s success probability toward *destination*, and K is the
        number of outgoing channels (successors) of *node* in the graph.

        Raises:
            KeyError: if *node* has no probability record.
        """
        prob = self.node_probabilities[node].get_probability(destination)
        # DiGraph.successors returns an iterator (networkx >= 2.0), so len() on it
        # raises TypeError; out_degree yields the successor count of a DiGraph node.
        k = self.G.out_degree(node)
        return (1 - prob) ** (k - 1)