diff --git a/pysixtrack/line.py b/pysixtrack/line.py index 7301b16..bd37201 100644 --- a/pysixtrack/line.py +++ b/pysixtrack/line.py @@ -328,6 +328,7 @@ def from_madx_sequence( exact_drift=False, drift_threshold=1e-6, install_apertures=False, + apply_madx_errors=False, ): line = cls(elements=[], element_names=[]) @@ -342,6 +343,9 @@ def from_madx_sequence( ): line.append_element(el, el_name) + if apply_madx_errors: + line._apply_madx_errors(sequence) + return line # error handling (alignment, multipole orders, ...): @@ -366,7 +370,7 @@ def find_element_ids(self, element_name): idx_after_el = idx_el + 1 return idx_el, idx_after_el - def add_offset_error_to(self, element_name, dx=0, dy=0): + def _add_offset_error_to(self, element_name, dx=0, dy=0): idx_el, idx_after_el = self.find_element_ids(element_name) xyshift = elements.XYShift(dx=dx, dy=dy) inv_xyshift = elements.XYShift(dx=-dx, dy=-dy) @@ -375,7 +379,7 @@ def add_offset_error_to(self, element_name, dx=0, dy=0): idx_after_el + 1, inv_xyshift, element_name + "_offset_out" ) - def add_aperture_offset_error_to(self, element_name, arex=0, arey=0): + def _add_aperture_offset_error_to(self, element_name, arex=0, arey=0): idx_el, idx_after_el = self.find_element_ids(element_name) idx_el_aper = idx_after_el - 1 if not self.element_names[idx_el_aper] == element_name + "_aperture": @@ -389,7 +393,7 @@ def add_aperture_offset_error_to(self, element_name, arex=0, arey=0): idx_after_el + 1, inv_xyshift, element_name + "_aperture_offset_out" ) - def add_tilt_error_to(self, element_name, angle): + def _add_tilt_error_to(self, element_name, angle): '''Alignment error of transverse rotation around s-axis. The element corresponding to the given `element_name` gets wrapped by SRotation elements with rotation angle @@ -418,7 +422,7 @@ def add_tilt_error_to(self, element_name, angle): self.insert_element(idx_el, srot, element_name + "_tilt_in") self.insert_element(idx_after_el + 1, inv_srot, element_name + "_tilt_out") - def add_multipole_error_to(self, element_name, knl=[], ksl=[]): + def _add_multipole_error_to(self, element_name, knl=[], ksl=[]): # will raise error if element not present: assert element_name in self.element_names element = self.elements[self.element_names.index(element_name)] @@ -435,13 +439,13 @@ def add_multipole_error_to(self, element_name, knl=[], ksl=[]): for i, component in enumerate(ksl): element.ksl[i] += component - def apply_madx_errors(self, error_table): - """Applies MAD-X error_table (with multipole errors, - dx and dy offset errors and dpsi tilt errors) - to existing elements in this Line instance. + def _apply_madx_errors(self, madx_sequence): + """Applies errors from MAD-X sequence to existing + elements in this Line instance. - Return error_table names which were not found in the - elements of this Line instance (and thus not treated). + Return names of MAD-X elements with existing align_errors + or field_errors which were not found in the elements of + this Line instance (and thus not treated). Example via cpymad: madx = cpymad.madx.Madx() @@ -449,83 +453,58 @@ def apply_madx_errors(self, error_table): # (...set up lattice and errors in cpymad...) seq = madx.sequence.some_lattice - # store already applied errors: - madx.command.esave(file='lattice_errors.err') - madx.command.readtable( - file='lattice_errors.err', table="errors") - errors = madx.table.errors - - pysixtrack_line = Line.from_madx_sequence(seq) - pysixtrack_line.apply_madx_errors(errors) + pysixtrack_line = pysixtrack.Line.from_madx_sequence( + seq, + apply_madx_errors=True + ) """ - max_multipole_err = 0 - # check for errors in table which cannot be treated yet: - for error_type in error_table.keys(): - if error_type == "name": - continue - if any(error_table[error_type]): - if error_type in ["dx", "dy", "dpsi", "arex", "arey"]: - # available alignment error - continue - elif error_type[:1] == "k" and error_type[-1:] == "l": - # available multipole error - order = int("".join(c for c in error_type if c.isdigit())) - max_multipole_err = max(max_multipole_err, order) - else: - print( - f'Warning: MAD-X error type "{error_type}"' - " not implemented yet." - ) - elements_not_found = [] - for i_line, element_name in enumerate(error_table["name"]): + for element, element_name in zip( + madx_sequence.expanded_elements, + madx_sequence.expanded_element_names() + ): if element_name not in self.element_names: - elements_not_found.append(element_name) - continue + if element.align_errors or element.field_errors: + elements_not_found.append(element_name) + continue - # add offset - try: - dx = error_table["dx"][i_line] - except KeyError: - dx = 0 - try: - dy = error_table["dy"][i_line] - except KeyError: - dy = 0 - if dx or dy: - self.add_offset_error_to(element_name, dx, dy) - - # add tilt - try: - dpsi = error_table["dpsi"][i_line] - except KeyError: - dpsi = 0 - if dpsi: - self.add_tilt_error_to(element_name, angle=dpsi / deg2rad) - - # add aperture-only offset - try: - arex = error_table["arex"][i_line] - except KeyError: - arex = 0 - try: - arey = error_table["arey"][i_line] - except KeyError: - arey = 0 - if arex or arey: - self.add_aperture_offset_error_to(element_name, arex, arey) - - # add multipole error - knl = [ - error_table[f"k{o}l"][i_line] - for o in range(max_multipole_err + 1) - ] - ksl = [ - error_table[f"k{o}sl"][i_line] - for o in range(max_multipole_err + 1) - ] - if any(knl) or any(ksl): - self.add_multipole_error_to(element_name, knl, ksl) + if element.align_errors: + # add offset + dx = element.align_errors.dx + dy = element.align_errors.dy + if dx or dy: + self._add_offset_error_to(element_name, dx, dy) + + # add tilt + dpsi = element.align_errors.dpsi + if dpsi: + self._add_tilt_error_to(element_name, angle=dpsi / deg2rad) + + # add aperture-only offset + arex = element.align_errors.arex + arey = element.align_errors.arey + if arex or arey: + self._add_aperture_offset_error_to(element_name, arex, arey) + + # check for errors which cannot be treated yet: + for error_type in dir(element.align_errors): + if not error_type[0] == '_' and \ + error_type not in ['dx', 'dy', 'dpsi', 'arex', + 'arey', 'count', 'index']: + print( + f'Warning: MAD-X error type "{error_type}"' + " not implemented yet." + ) + + if element.field_errors: + # add multipole error + if any(element.field_errors.dkn) or \ + any(element.field_errors.dks): + knl = element.field_errors.dkn + ksl = element.field_errors.dks + knl = knl[:np.amax(np.where(knl)) + 1] # delete trailing zeros + ksl = ksl[:np.amax(np.where(ksl)) + 1] # to keep order low + self._add_multipole_error_to(element_name, knl, ksl) return elements_not_found diff --git a/tests/test_line.py b/tests/test_line.py index d46cb92..41272e0 100644 --- a/tests/test_line.py +++ b/tests/test_line.py @@ -34,24 +34,24 @@ def test_line(): n_elements += 1 assert len(line) == n_elements - line.add_offset_error_to(multipole_name, dx=0, dy=0) + line._add_offset_error_to(multipole_name, dx=0, dy=0) n_elements += 2 assert len(line) == n_elements - line.add_offset_error_to(multipole_name, dx=0.2, dy=-0.003) + line._add_offset_error_to(multipole_name, dx=0.2, dy=-0.003) n_elements += 2 assert len(line) == n_elements - line.add_tilt_error_to(multipole_name, angle=0) + line._add_tilt_error_to(multipole_name, angle=0) n_elements += 2 assert len(line) == n_elements - line.add_tilt_error_to(multipole_name, angle=0.1) + line._add_tilt_error_to(multipole_name, angle=0.1) n_elements += 2 assert len(line) == n_elements - line.add_multipole_error_to(multipole_name, knl=[0, 0.1], ksl=[-0.03, 0.01]) - # line.add_multipole_error_to(drift_exact,knl=[0,0.1],ksl=[-0.03,0.01]) + line._add_multipole_error_to(multipole_name, knl=[0, 0.1], ksl=[-0.03, 0.01]) + # line._add_multipole_error_to(drift_exact,knl=[0,0.1],ksl=[-0.03,0.01]) line_dict = line.to_dict() line = pysixtrack.Line.from_dict(line_dict) diff --git a/tests/test_madx_import.py b/tests/test_madx_import.py index 3b04029..69085bd 100644 --- a/tests/test_madx_import.py +++ b/tests/test_madx_import.py @@ -153,18 +153,16 @@ def test_error_import(): select, flag = error, clear; select, flag = error, pattern = "MQ3"; ealign, dx = 0.00, dy = 0.00, arex = 0.00, arey = 0.00, dpsi = 0.00; - efcomp, DKN = {0.0, 0.0, 0.001, 0.002}, DKS = {0.0, 0.0, 0.003, 0.004}; + efcomp, DKN = {0.0, 0.0, 0.001, 0.002}, DKS = {0.0, 0.0, 0.003, 0.004, 0.005}; select, flag = error, full; ''') seq = madx.sequence.testseq - # store already applied errors: - madx.command.esave(file='lattice_errors.err') - madx.command.readtable(file='lattice_errors.err', table="errors") - os.remove('lattice_errors.err') - errors = madx.table.errors - - pysixtrack_line = pysixtrack.Line.from_madx_sequence(seq, install_apertures=True) - pysixtrack_line.apply_madx_errors(errors) + + pysixtrack_line = pysixtrack.Line.from_madx_sequence( + seq, + install_apertures=True, + apply_madx_errors=True, + ) madx.input('stop;') expected_element_num = ( @@ -224,6 +222,7 @@ def test_error_import(): assert abs(MQ3.knl[3] - 0.002) < 1e-14 assert abs(MQ3.ksl[2] - 0.003) < 1e-14 assert abs(MQ3.ksl[3] - 0.004) < 1e-14 + assert abs(MQ3.ksl[4] - 0.005) < 1e-14 def test_neutral_errors(): @@ -256,7 +255,7 @@ def test_neutral_errors(): USE, SEQUENCE=testseq; - Select, flag=makethin, pattern="MQ1", slice=2; + Select, flag=makethin, pattern="T1", slice=2; makethin, sequence=testseq; use, sequence=testseq; @@ -274,14 +273,12 @@ def test_neutral_errors(): select, flag = error, full; ''') seq = madx.sequence.testseq - # store already applied errors: - madx.command.esave(file='lattice_errors.err') - madx.command.readtable(file='lattice_errors.err', table="errors") - os.remove('lattice_errors.err') - errors = madx.table.errors - - pysixtrack_line = pysixtrack.Line.from_madx_sequence(seq, install_apertures=True) - pysixtrack_line.apply_madx_errors(errors) + + pysixtrack_line = pysixtrack.Line.from_madx_sequence( + seq, + install_apertures=True, + apply_madx_errors=True, + ) madx.input('stop;') initial_x = 0.025 @@ -296,3 +293,152 @@ def test_neutral_errors(): assert abs(particle.x-initial_x) < 1e-14 assert abs(particle.y-initial_y) < 1e-14 + + +def test_error_functionality(): + # check if errors are actually working as intended + cpymad_spec = util.find_spec("cpymad") + if cpymad_spec is None: + print("cpymad is not available - abort test") + sys.exit(0) + + from cpymad.madx import Madx + import numpy as np + + madx = Madx() + + madx.input(''' + T1: Collimator, L=0.0, apertype=CIRCLE, aperture={0.5}; + T2: Marker; + T3: Collimator, L=0.0, apertype=CIRCLE, aperture={0.5}; + + testseq: SEQUENCE, l = 20.0; + T1, at = 5; + T2, at = 10; + T3, at = 15; + ENDSEQUENCE; + + !---the usual stuff + BEAM, PARTICLE=PROTON, ENERGY=7000.0, EXN=2.2e-6, EYN=2.2e-6; + USE, SEQUENCE=testseq; + + !---assign misalignments and field errors + select, flag = error, clear; + select, flag = error, pattern = "T1"; + ealign, dx = 0.01, dy = 0.02, arex = 0.03, arey = 0.04; + select, flag = error, clear; + select, flag = error, pattern = "T3"; + ealign, dx = 0.07, dy = 0.08, dpsi = 0.7, arex = 0.08, arey = 0.09; + select, flag = error, full; + ''') + seq = madx.sequence.testseq + + pysixtrack_line = pysixtrack.Line.from_madx_sequence( + seq, + install_apertures=True, + apply_madx_errors=True, + ) + madx.input('stop;') + + x_init = 0.1*np.random.rand(10) + y_init = 0.1*np.random.rand(10) + particles = pysixtrack.Particles( + x=x_init.copy(), + y=y_init.copy() + ) + + T1_checked = False + T1_aper_checked = False + T2_checked = False + T3_checked = False + T3_aper_checked = False + for element, element_name in zip(pysixtrack_line.elements, + pysixtrack_line.element_names): + ret = element.track(particles) + + if element_name == 't1': + T1_checked = True + assert np.all(abs(particles.x - (x_init - 0.01)) < 1e-14) + assert np.all(abs(particles.y - (y_init - 0.02)) < 1e-14) + if element_name == 't1_aperture': + T1_aper_checked = True + assert np.all(abs(particles.x - (x_init - 0.01 - 0.03)) < 1e-14) + assert np.all(abs(particles.y - (y_init - 0.02 - 0.04)) < 1e-14) + if element_name == 't2': + T2_checked = True + assert np.all(abs(particles.x - x_init) < 1e-14) + assert np.all(abs(particles.y - y_init) < 1e-14) + cospsi = np.cos(0.7) + sinpsi = np.sin(0.7) + if element_name == 't3': + T3_checked = True + assert np.all(abs( + particles.x + - (x_init - 0.07)*cospsi + - (y_init - 0.08)*sinpsi + ) < 1e-14) + assert np.all(abs( + particles.y + + (x_init - 0.07)*sinpsi + - (y_init - 0.08)*cospsi + ) < 1e-14) + if element_name == 't3_aperture': + T3_aper_checked = True + assert np.all(abs( + particles.x + - (x_init - 0.07)*cospsi + - (y_init - 0.08)*sinpsi + - (-0.08) + ) < 1e-14) + assert np.all(abs( + particles.y + + (x_init - 0.07)*sinpsi + - (y_init - 0.08)*cospsi + - (-0.09) + ) < 1e-14) + + if ret is not None: + break + + assert not ret + assert np.all([T1_checked, T1_aper_checked, + T2_checked, T3_checked, T3_aper_checked]) + + +def test_zero_errors(): + # check that zero-errors are loaded without erro + cpymad_spec = util.find_spec("cpymad") + if cpymad_spec is None: + print("cpymad is not available - abort test") + sys.exit(0) + + from cpymad.madx import Madx + + madx = Madx() + madx.input(''' + qd: multipole, knl={0,-0.3}; + qf: multipole, knl={0, 0.3}; + testseq: sequence, l = 1; + qd, at = 0.3; + qf, at = 0.6; + endsequence; + ''') + madx.select(flag='error', pattern='qf') + madx.command.efcomp( + dkn=[0, 0, 0, 0, 0.0, 0.0, 0.0], + dks=[0.0, 0.0, 0, 0] + ) + madx.command.ealign( + dx=0.0, + dy=0.0, + ds=0.0, + DPHI=0.0, + DTHETA=0.0, + DPSI=0.0, + MREX=0.0, + MREY=0.0, + MSCALX=0.0, + MSCALY=0.0, + AREX=0.0, + AREY=0.0 + )