Skip to content

Commit

Permalink
add knot tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
KindXiaoming committed Jul 17, 2024
1 parent c903302 commit 4e9af77
Show file tree
Hide file tree
Showing 7 changed files with 953 additions and 31 deletions.
223 changes: 192 additions & 31 deletions kan/MultKAN.py

Large diffs are not rendered by default.

55 changes: 55 additions & 0 deletions kan/experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
from .MultKAN import *


def runner1(width, dataset, grids=[5,10,20], steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2, node_th=1e-2, metrics=None, seed=1):

result = {}
result['test_loss'] = []
result['c'] = []
result['G'] = []
result['id'] = []
if metrics != None:
for i in range(len(metrics)):
result[metrics[i].__name__] = []

def collect(evaluation):
result['test_loss'].append(evaluation['test_loss'])
result['c'].append(evaluation['n_edge'])
result['G'].append(evaluation['n_grid'])
result['id'].append(f'{model.round}.{model.state_id}')
if metrics != None:
for i in range(len(metrics)):
result[metrics[i].__name__].append(metrics[i](model, dataset).item())

for i in range(prune_round):
# train and prune
if i == 0:
model = KAN(width=width, grid=grids[0], seed=seed)
else:
model = model.rewind(f'{i-1}.{2*i}')

model.fit(dataset, steps=steps, lamb=lamb)
model = model.prune(edge_th=edge_th, node_th=node_th)
evaluation = model.evaluate(dataset)
collect(evaluation)

for j in range(refine_round):
model = model.refine(grids[j])
model.fit(dataset, steps=steps)
evaluation = model.evaluate(dataset)
collect(evaluation)

for key in list(result.keys()):
result[key] = np.array(result[key])

return result


def pareto_frontier(x,y):

pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0]
x_pf = x[pf_id]
y_pf = y[pf_id]

return x_pf, y_pf, pf_id
4 changes: 4 additions & 0 deletions kan/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def extend_grid(grid, k_extend=0):

value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + (
grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]

# in case grid is degenerate
value = torch.nan_to_num(value)
return value


Expand Down Expand Up @@ -164,6 +167,7 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"):
mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef) # (in_dim, out_dim, batch, n_coef)
# coef shape: (in_dim, outdim, G+k)
y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3) # y_eval: (in_dim, out_dim, batch, 1)
#print(mat)
coef = torch.linalg.lstsq(mat.to(device), y_eval.to(device),
driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
return coef.to(device)
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

188 changes: 188 additions & 0 deletions tutorials/Example_14_knot_supervised.ipynb

Large diffs are not rendered by default.

163 changes: 163 additions & 0 deletions tutorials/Example_15_knot_unsupervised.ipynb

Large diffs are not rendered by default.

0 comments on commit 4e9af77

Please sign in to comment.