diff --git a/tests/test_tensor/test_fitting.py b/tests/test_tensor/test_fitting.py index 12295cc7..a509a233 100644 --- a/tests/test_tensor/test_fitting.py +++ b/tests/test_tensor/test_fitting.py @@ -12,13 +12,16 @@ @pytest.mark.parametrize("method", ("auto", "dense", "overlap")) -@pytest.mark.parametrize("normalized", ( - True, - False, - "squared", - "infidelity", - "infidelity_sqrt", -)) +@pytest.mark.parametrize( + "normalized", + ( + True, + False, + "squared", + "infidelity", + "infidelity_sqrt", + ), +) def test_tensor_network_distance(method, normalized): n = 6 A = qtn.TN_rand_reg(n=n, reg=3, D=2, phys_dim=2, dtype=complex) @@ -34,87 +37,100 @@ def test_tensor_network_distance(method, normalized): @pytest.mark.parametrize( - "method,opts", - ( - ("als", (("enforce_pos", False), ("solver", "lstsq"))), - ("als", (("enforce_pos", True),)), - ("tree", ()), + "opts", + [ + dict(method="als", dense_solve=False), + dict(method="als", dense_solve=False, solver="lgmres"), + dict( + method="als", + dense_solve=True, + enforce_pos=False, + solver_dense="lstsq", + ), + dict(method="als", dense_solve=True, enforce_pos=True), + dict(method="tree"), pytest.param( - "autodiff", - (("distance_method", "dense"),), + dict(method="autodiff", distance_method="dense"), marks=requires_autograd, ), pytest.param( - "autodiff", - (("distance_method", "overlap"),), + dict(method="autodiff", distance_method="overlap"), marks=requires_autograd, ), - ), + ], ) @pytest.mark.parametrize("dtype", ("float64", "complex128")) -def test_fit_mps(method, opts, dtype): +def test_fit_mps(opts, dtype): k1 = qtn.MPS_rand_state(5, 3, seed=666, dtype=dtype) k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype) assert k1.distance_normalized(k2) > 1e-3 - k1.fit_(k2, method=method, progbar=True, **dict(opts)) + k1.fit_(k2, progbar=True, **dict(opts)) assert k1.distance_normalized(k2) < 1e-3 @pytest.mark.parametrize( - "method,opts", - ( - ("als", (("enforce_pos", False),)), - ("als", (("enforce_pos", True),)), + "opts", + [ + dict(method="als", dense_solve=False), + dict(method="als", dense_solve=False, solver="lgmres"), + dict( + method="als", + dense_solve=True, + enforce_pos=False, + solver_dense="lstsq", + ), + dict(method="als", dense_solve=True, enforce_pos=True), pytest.param( - "autodiff", - (("distance_method", "dense"),), + dict(method="autodiff", distance_method="dense"), marks=requires_autograd, ), pytest.param( - "autodiff", - (("distance_method", "overlap"),), + dict(method="autodiff", distance_method="overlap"), marks=requires_autograd, ), - ), + ], ) @pytest.mark.parametrize("dtype", ("float64", "complex128")) -def test_fit_rand_reg(method, opts, dtype): +def test_fit_rand_reg(opts, dtype): r1 = qtn.TN_rand_reg(5, 4, D=2, seed=666, phys_dim=2, dtype=dtype) k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype) assert r1.distance(k2) > 1e-3 - r1.fit_(k2, method=method, progbar=True, **dict(opts)) + r1.fit_(k2, progbar=True, **dict(opts)) assert r1.distance(k2) < 1e-3 @pytest.mark.parametrize( - "method,opts", - ( - ("als", (("enforce_pos", False),)), - ("als", (("enforce_pos", True),)), - ("tree", ()), + "opts", + [ + dict(method="als", dense_solve=False), + dict(method="als", dense_solve=False, solver="lgmres"), + dict( + method="als", + dense_solve=True, + enforce_pos=False, + solver_dense="lstsq", + ), + dict(method="als", dense_solve=True, enforce_pos=True), + dict(method="tree"), pytest.param( - "autodiff", - (("distance_method", "dense"),), + dict(method="autodiff", distance_method="dense"), marks=requires_autograd, ), pytest.param( - "autodiff", - (("distance_method", "overlap"),), + dict(method="autodiff", distance_method="overlap"), marks=requires_autograd, ), - ), + ], ) @pytest.mark.parametrize("dtype", ("float64", "complex128")) -def test_fit_partial_tags(method, opts, dtype): +def test_fit_partial_tags(opts, dtype): k1 = qtn.MPS_rand_state(5, 3, seed=666, dtype=dtype) k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype) d0 = k1.distance(k2) tags = ["I0", "I2", "I4"] - k1f = k1.fit( - k2, tol=1e-3, tags=tags, method=method, progbar=True, **dict(opts) - ) + k1f = k1.fit(k2, tol=1e-3, tags=tags, progbar=True, **dict(opts)) assert k1f.distance(k2) < d0 - if method != "tree": + if opts["method"] != "tree": assert (k1f[0] - k1[0]).norm() > 1e-12 assert (k1f[1] - k1[1]).norm() < 1e-12 assert (k1f[2] - k1[2]).norm() > 1e-12