Skip to content

Commit

Permalink
Make interpolation settings consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
recisic committed Jun 18, 2024
1 parent 1b60ef0 commit e19a538
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion geodesic_cv/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def geodesic_interpolation(
n_frame, n_point, dim = xyz_unfolded.shape
rng = np.random.default_rng(seed)

manifold = PointCloud(dim=dim, numpoints=n_point, base=xyz_unfolded[0], alpha=alpha)
manifold = PointCloud(dim=dim, numpoints=n_point, base=xyz_folded[0], alpha=alpha)
xyz_unfolded = manifold.align_mpoint(xyz_unfolded[None]).squeeze(0)
xyz_folded = manifold.align_mpoint(xyz_folded[None]).squeeze(0)

Expand Down
12 changes: 8 additions & 4 deletions scripts/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@
torch.set_default_device(device)

# Trajectory
traj_unfolded = load_xtc(args.xtc_unfolded, pdb_folded=args.pdb_folded)
traj_folded = load_xtc(args.xtc_folded, pdb_folded=args.pdb_folded)
xyz_unfolded = torch.tensor(traj_unfolded.xyz)[:: args.traj_stride]
xyz_folded = torch.tensor(traj_folded.xyz)[:: args.traj_stride]
traj_unfolded = load_xtc(
args.xtc_unfolded, pdb_folded=args.pdb_folded, stride=args.traj_stride
)
traj_folded = load_xtc(
args.xtc_folded, pdb_folded=args.pdb_folded, stride=args.traj_stride
)
xyz_unfolded = torch.tensor(traj_unfolded.xyz)
xyz_folded = torch.tensor(traj_folded.xyz)

# Interpolation
if args.interp_method == "gaussian":
Expand Down

0 comments on commit e19a538

Please sign in to comment.