Source code for rayoptics.elem.parttree

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright © 2021 Michael J. Hayford
"""Manage connectivity between sequence and element models using a tree.

.. Created on Sat Jan 23 20:15:48 2021

.. codeauthor: Michael J. Hayford
"""
from itertools import zip_longest

from anytree import Node, RenderTree, PreOrderIter
from anytree.exporter import DictExporter
from anytree.importer import DictImporter
from anytree.search import find_by_attr

import rayoptics.elem.elements as ele_module
import rayoptics.elem.sgz2ele as sgz2ele
import rayoptics.oprops.thinlens as thinlens
from rayoptics.util import str_to_class


[docs] class PartTree(): def __init__(self, opt_model, **kwargs): self.opt_model = opt_model self.root_node = Node('root', id=self, tag='#group#root') def __json_encode__(self): attrs = dict(vars(self)) del attrs['opt_model'] def node_attrs(attrs_): attrs = [] for k, v in attrs_: if k == "name" or k == "tag": attrs.append((k, v)) elif k == "id": attrs.append(("id_key", str(id(v)))) return attrs exporter = DictExporter(attriter=node_attrs) attrs['root_node'] = exporter.export(self.root_node) return attrs
[docs] def sync_to_restore(self, opt_model): self.opt_model = opt_model if hasattr(self, 'root_node'): root_node_compressed = self.root_node importer = DictImporter() self.root_node = importer.import_(root_node_compressed) self.root_node.id = self if hasattr(self.root_node, 'id_key'): sync_part_tree_on_restore_idkey(self.opt_model, self.opt_model.ele_model, self.opt_model.seq_model, self.root_node) else: sync_part_tree_on_restore(self.opt_model, self.opt_model.ele_model, self.opt_model.seq_model, self.root_node)
[docs] def update_model(self, **kwargs): seq_model = self.opt_model.seq_model self.handle_object_image_tags(seq_model) self.sync_part_tree_on_update(self.opt_model.ele_model, seq_model, self.root_node) self.sort_tree_using_sequence(seq_model)
[docs] def is_empty(self): if (isinstance(self.root_node, Node) and len(self.root_node.children) == 0): return True else: return False
[docs] def check_consistency(self, seq_model, ele_model): chg_list, sme, eme, ame = find_ele_changes(seq_model, ele_model, self) common_ele, added_ele, removed_ele, modified_ele = chg_list if len(added_ele)==0 and len(removed_ele)==0 and len(modified_ele)==0: return True else: return False
[docs] def init_from_sequence(self, seq_model): """Initialize part tree using a *seq_model*. """ root_node = self.root_node for i, sgz in enumerate(zip_longest(seq_model.ifcs, seq_model.gaps, seq_model.z_dir)): s, gap, z_dir = sgz Node(f'i{i}', id=s, tag='#ifc', parent=root_node) if gap is not None: Node(f'g{i}', id=(gap, z_dir), tag='#gap', parent=root_node)
[docs] def sort_tree_using_sequence(self, seq_model): """Resequence part tree using a *seq_model*. """ def parse_path(path2root, groups): for i, node in enumerate(path2root[1:]): parent_node = path2root[i] if parent_node in groups: if node not in groups[parent_node]: groups[parent_node].append(node) else: groups[parent_node] = [node] groups = {} for i, sgz in enumerate(zip_longest(seq_model.ifcs, seq_model.gaps, seq_model.z_dir)): ifc, gap, z_dir = sgz ifc_node = self.node(ifc) if ifc_node is not None: # path2root = ifc_node.path[:-2] path2root = self.nodes_with_tag( node_list=ifc_node.path, tag='#element#space#airgap#dummyifc#group') parse_path(path2root, groups) if gap is not None: gap_node = self.node((gap, z_dir)) if gap_node is not None: # path2root = gap_node.path[:-2] path2root = self.nodes_with_tag( node_list=gap_node.path, tag='#element#space#airgap#dummyifc#group') parse_path(path2root, groups) for group_node, child_list in groups.items(): group_node.children = child_list
[docs] def add_element_model_to_tree(self, ele_model): for e in ele_model.elements: if hasattr(e, 'tree'): self.add_element_to_tree(e) return self
[docs] def add_element_to_tree(self, e, **kwargs): e_node = e.tree(**kwargs) e_node.name = e.label leaves = e_node.leaves for leaf_node in leaves: dup_node = self.node(leaf_node.id) if dup_node is not None: dup_node.parent = None e_node.parent = self.root_node return e_node
[docs] def remove_element_from_tree(self, e, **kwargs): e_node = self.node(e) e_node.children = [] e_node.parent = None return e_node
[docs] def remove_node(self, e_node, merge=True, **kwargs): """ Remove e_node and all related seq_model dependencies. If merge is True, delete the AirGap following e_node and add it's thickness to the AirGap preceding e_node. """ sm = self.opt_model['seq_model'] nodes = self.nodes_with_tag(tag='#element#assembly#airgap#dummyifc', root=e_node) ifcs = [n.id for n in self.nodes_with_tag(tag='#ifc', root=e_node)] if len(ifcs) > 0: idx_1, idx_k = sm.ifcs.index(ifcs[0]), sm.ifcs.index(ifcs[-1]) else: # don't allow deletion of solo airgaps return if merge: idx_0 = idx_1-1 if idx_1 > 0 else 0 prev_ag, _ = self.parent_object( (sm.gaps[idx_0], sm.z_dir[idx_0]), tag='#space') after_ag, after_agn = self.parent_object( (sm.gaps[idx_k], sm.z_dir[idx_k]), tag='#space') prev_ag.gaps[0].thi += after_ag.gaps[0].thi else: after_agn = None # remove interfaces from seq_model sm.remove_node(idx_1, idx_k, merge, **kwargs) # remove elements associated with the nodes eles = [n.id for n in nodes] for e in eles: self.opt_model['ele_model'].remove_element(e) # detach all of the descendants of e_node and after_agn for node in e_node.descendants: node.parent = None if after_agn: for agn in after_agn.descendants: agn.parent = None e_node.parent = None
[docs] def node(self, obj): """ Return the node paired with `obj`. """ return find_by_attr(self.root_node, name='id', value=obj)
[docs] def obj_by_name(self, name): """ Return the node paired with `obj`. """ return find_by_attr(self.root_node, name='name', value=name).id
[docs] def trim_node(self, obj): """ Remove the branch where `obj` is the sole leaf. """ leaf_node = self.node(obj) parent_node = None if leaf_node: parent_node = leaf_node.parent while parent_node is not None: if len(parent_node.children) > 1: # parent has more than one child, trim leaf_node leaf_node.parent = None break else: # trim leaf_node and continue up the branch leaf_node = parent_node parent_node = leaf_node.parent leaf_node.parent = None
[docs] def parent_node(self, obj, tag='#element#space#dummyifc'): """ Return the parent node for `obj`, filtered by `tag`. """ tags = tag.split('#')[1:] leaf_node = self.node(obj) parent_node = leaf_node.parent if leaf_node else None while parent_node is not None: for t in tags: if t in parent_node.tag: return parent_node parent_node = parent_node.parent return parent_node
[docs] def parent_object(self, obj, tag='#element#space#dummyifc'): """ Return the parent object (and node) for `obj`, filtered by `tag`. """ parent_node = self.parent_node(obj, tag) parent = parent_node.id if parent_node else None return parent, parent_node
[docs] def get_child_filter(self, tag='#element#assembly', not_tag=''): """ Returns a fct that filters a list of nodes to satisfy the tags""" def children_with_tag(children): return self.nodes_with_tag(node_list=children, tag=tag, not_tag=not_tag) return children_with_tag
[docs] def list_tree(self, *args, **kwargs): """ Print a graphical console representation of the tree. The optional arguments are passed through to the by_attr filter. Useful examples or arguments include: - pt.list_tree(lambda node: f"{node.name}: {node.tag}") - pt.list_tree(attrname='tag') """ list_tree_from_node(self.root_node, *args, **kwargs)
[docs] def list_tree_full(self): """ Print a graphical console representation of the tree with tags. """ self.list_tree(lambda node: f"{node.name}: {node.tag}")
[docs] def nodes_with_tag(self, tag='#element', not_tag='', root=None, node_list=None): """ Return a list of nodes that contain the requested `tag`. """ def tag_filter(tags): def filter_tagged_node(node): for t in tags: if t in node.tag: for n_t in not_tags: if n_t in node.tag: return False return True return False return filter_tagged_node tags = tag.split('#')[1:] not_tags = not_tag.split('#')[1:] if root is not None: nodes = [node for node in PreOrderIter(root, filter_=tag_filter(tags))] elif node_list is not None: filter_node = tag_filter(tags) nodes = [node for node in node_list if filter_node(node)] else: root_node = self.root_node nodes = [node for node in PreOrderIter(root_node, filter_=tag_filter(tags))] return nodes
[docs] def list_model(self, tag='#element#assembly#dummyifc'): self.list_tree(childiter=self.get_child_filter(tag=tag))
[docs] def build_pt_sg_lists(self): part_tag = '#assembly' nodes = self.nodes_with_tag(tag=part_tag) asms = [n.id for n in nodes] asm_list = [] asm_dict = {} seq_model = self.opt_model['seq_model'] for asm in asms: for p in asm.parts: ele_def = ele_module.build_ele_def(p, seq_model) asm_list.append(ele_def) asm_dict[ele_def] = asm return asm_list, asm_dict
[docs] def list_pt_sg(self): ele_list, ele_dict = self.build_pt_sg_lists() for elem in ele_list: ele_type, idx_list, gap_list = elem e = ele_dict[elem] print(f"{e.label}: {ele_type[0]} {idx_list} {gap_list}")
[docs] def sync_part_tree_on_update(self, ele_model, seq_model, root_node): """Update node names to track element labels. The node labels for the children of parts are handled here. At the leaf level, the node names encode the interface or gap index. At the element level, the indexing is modified for element flip. """ ele_dict = {e.label: e for e in ele_model.elements} for node in PreOrderIter(root_node): name, tag = node.name, node.tag # handle leaf node labeling if tag == '#ifc': idx = seq_model.ifcs.index(node.id) node.name = f'i{idx}' elif tag == '#gap': gap, z_dir = node.id idx = seq_model.gaps.index(gap) z_dir = seq_model.z_dir[idx] node.id = (gap, z_dir) node.name = f'g{idx}' # handle element children labeling elif tag == '#profile': p_name = node.parent.name e = ele_dict[p_name] idx = int(name[1:])-1 if len(name) > 1 else 0 num_idxs = len(e.idx_list()) idx = (num_idxs-idx-1) if e.is_flipped else idx node.id = e.profile_list()[idx] elif tag == '#thic': p_name = node.parent.name e = ele_dict[p_name] idx = int(name[1:])-1 if len(name) > 1 else 0 node.id = e.gap_list()[idx] # skip the root elif '#root' in tag: pass # update node names to track element labels (finally!) else: if hasattr(node, 'id'): if hasattr(node.id, 'label'): node.name = node.id.label else: print(f"sync_part_tree_on_update: No id attribute: " f"{node.name=}, {node.tag=}")
[docs] def handle_object_image_tags(self, seq_model): """ Ensure nodes for object and image ifcs and gaps are tagged. """ oi_ifc_tags = '#dummyifc#surface' oi_gap_tags = '#space#airgap' self._handle_oi_tag(seq_model.ifcs[0], '#object', not_tag=oi_gap_tags, parent_tag=oi_ifc_tags) self._handle_oi_tag((seq_model.gaps[0], seq_model.z_dir[0]), '#object', not_tag=oi_ifc_tags, parent_tag=oi_gap_tags) self._handle_oi_tag(seq_model.ifcs[-1], '#image', not_tag=oi_gap_tags, parent_tag=oi_ifc_tags) self._handle_oi_tag((seq_model.gaps[-1],seq_model.z_dir[-1]), '#image', not_tag=oi_ifc_tags, parent_tag=oi_gap_tags)
def _handle_oi_tag(self, sm_leaf_id, oi_tag, parent_tag, not_tag=''): """ using sm_leaf_id, find parent a """ nodes = self.nodes_with_tag(tag=oi_tag, not_tag=not_tag) oi_node = self.parent_node(sm_leaf_id, tag=parent_tag) found_it = False for n in nodes: if n != oi_node: n.tag = n.tag.replace(oi_tag, '') else: found_it = True if not found_it and oi_node is not None: oi_node.tag += oi_tag
[docs] def sync_part_tree_on_restore(opt_model, ele_model, seq_model, root_node): ele_dict = {e.label: e for e in ele_model.elements} for node in PreOrderIter(root_node): name, tag = node.name, node.tag if name in ele_dict: node.id = ele_dict[name] elif '#ifc' in tag: idx = int(name[1:]) node.id = seq_model.ifcs[idx] elif '#gap' in tag: idx = int(name[1:]) node.id = (seq_model.gaps[idx], seq_model.z_dir[idx]) elif '#profile' in tag: p_name = node.parent.name e = ele_dict[p_name] try: idx = int(name[1:]) - 1 except ValueError: idx = 0 node.id = e.profile_list()[idx] elif '#thic' in tag: p_name = node.parent.name e = ele_dict[p_name] idx = int(name[1:])-1 if len(name) > 1 else 0 node.id = e.gap_list()[idx]
[docs] def sync_part_tree_on_restore_idkey(opt_model, ele_model, seq_model, root_node): for node in PreOrderIter(root_node): name, tag = node.name, node.tag if node.id_key in opt_model.parts_dict: node.id = opt_model.parts_dict[node.id_key] elif '#ifc' in tag: if '#tl' in tag or name[:2] == 'tl': # ThinElement <- ThinLens e = opt_model.parts_dict[node.parent.id_key] node.id = e.intrfc elif '#di' in tag or name[:2] == 'di': # DummyInterface <- Profile <- Interface e_node = node.parent.parent node.id = e_node.id.ref_ifc else: idx = int(name[1:]) node.id = seq_model.ifcs[idx] elif '#gap' in tag: idx = int(name[1:]) node.id = (seq_model.gaps[idx], seq_model.z_dir[idx]) elif '#profile' in tag: node.id = opt_model.profile_dict[node.id_key] elif '#thic' in tag: e = opt_model.parts_dict[node.parent.id_key] idx = int(name[1:])-1 if len(name) > 1 else 0 node.id = e.gap_list()[idx] for node in PreOrderIter(root_node): delattr(node, 'id_key')
[docs] def sequence_to_elements(seq_model, ele_model, part_tree): """ Parse the seq_model into elements and update ele_model accordingly. """ chg_list, sme, eme, ame = find_ele_changes(seq_model, ele_model, part_tree) common_ele, added_ele, removed_ele, modified_ele = chg_list sme_list, seq_str = sme eme_list, eme_dict = eme if len(added_ele)==0 and len(removed_ele)==0 and len(modified_ele)==0: # no additions, deletions or modifications to existing elements # are necessary is_consistent = True else: # now the lists represent the 3 possible actions: adding, deleting, # and modifying an existing element # elements in the added_ele list are created from the `ele_type` if len(added_ele) > 1: added_ele.sort(key=lambda ae: ae[1][0] if len(ae[1])>0 else ae[2][0]) for ae in added_ele: (ele_token, ele_module, ele_class), *_ = ae e = str_to_class(ele_module, ele_class, ele_def_pkg=(seq_model, ae)) ele_model.add_element(e) idx = e.reference_idx() z_dirs = seq_model.z_dir z_dir = z_dirs[idx] if idx < len(z_dirs) else z_dirs[idx-1] part_tree.add_element_to_tree(e, z_dir=z_dir) # items in the removed_ele list are removed from the ele_model # and part_tree for re in removed_ele: relem = eme_dict[re] ele_model.remove_element(relem) part_tree.remove_element_from_tree(relem) # modified elements use the `sync_to_ele_def` protocol to update an # existing element to a new ele_def for me in modified_ele: existing_ele, new_ele = me e = eme_dict[existing_ele] e_node = part_tree.node(e) # update the element definition e.sync_to_ele_def(seq_model, new_ele) # rebuild subtree new_e_node = e.tree() new_e_node.tag = e_node.tag new_e_node.parent = e_node.parent e_node.parent = None # common elements use the `sync_to_seq` protocol to adjust to any additions # or removals of other elements. for ce in common_ele: eme_dict[ce].sync_to_seq(seq_model) is_consistent = False return (is_consistent, chg_list, sme, eme)
[docs] def find_ele_changes(seq_model, ele_model, part_tree, print_visit=False): """ Parse the seq_model into elements and categorize the changes. Returns: common_ele: list of ele_defs in common between sm and pt added_ele: list of ele_defs for new elements to be created removed_ele: list of ele_defs to be removed modified_ele: list of existing elements to be updated from new ele_defs sme_list: ele_defs obtained by parsing the seq_model seq_str: character encoding of seq_model ifcs and gaps eme_list: ele_defs for current elements in the element model eme_dict: key: ele_def returns the value: element asm_list: ele_defs for current assemblies in the element/part model asm_dict: key: ele_def the value: assembly """ # get sequential model "parse string" seq_str = seq_model.seq_str() if seq_str != '': sgz2ele_tree = sgz2ele.sgz2ele_grammar.parse(seq_str) sgz2ele_sm = sgz2ele.SMVisitor() sgz2ele_sm.do_print_visit = print_visit sgz2ele_visit = sgz2ele_sm.visit(sgz2ele_tree) sme_list = sgz2ele.flatten_visit(sgz2ele_visit) eme_list, eme_dict = ele_model.build_ele_sg_lists() eme_set = set(eme_list) sme_set = set(sme_list) common_ele = list(sme_set.intersection(eme_set)) added_ele = list(sme_set.difference(eme_set)) removed_ele = list(eme_set.difference(sme_set)) # the modified element list is constructed from ele_defs with the same # entity type and whose first gap indices match. modified_ele = [(re, ae) for ae in added_ele for re in removed_ele if (ae[0][2] == re[0][2] and ae[2][0] == re[2][0])] # remove the modified_eles from their original lists for me in modified_ele: re, ae = me added_ele.remove(ae) removed_ele.remove(re) asm_list, asm_dict = part_tree.build_pt_sg_lists() return ((common_ele, added_ele, removed_ele, modified_ele), (sme_list, seq_str), (eme_list, eme_dict), (asm_list, asm_dict)) else: return (([], [], [], []), ([], ''), ([], {}), ([], {}))
[docs] def part_list_from_seq(opt_model, idx1, idx2): """Using the part_tree, return the parts for the input sequence range. """ sm = opt_model['seq_model'] seq = zip_longest(sm.ifcs[idx1:idx2+1], sm.gaps[idx1:idx2], sm.z_dir[idx1:idx2]) part_tree = opt_model['part_tree'] node_set = set() for ifc, gap, z_dir in seq: parent_part = part_tree.parent_node(ifc) node_set.add(parent_part) if gap is not None: parent_part = part_tree.parent_node((gap, z_dir)) node_set.add(parent_part) node_list = sorted(node_set, key=lambda node: node.id.reference_idx()) part_list = [node.id for node in node_list] return part_list, node_list
[docs] def list_tree_all_from_node(node, **kwargs): """ List the tree from `node` with full node output. """ tag_filter = kwargs.pop('childiter', list) print(RenderTree(node, childiter=tag_filter))
[docs] def list_tree_from_node(node, *args, **kwargs): """ List the tree from `node` with attribute filtering. The optional arguments are passed through to the by_attr filter. Useful examples or arguments include: - pt.list_tree(lambda node: f"{node.name}: {node.tag}") - pt.list_tree(attrname='tag') """ tag_filter = kwargs.pop('childiter', list) print(RenderTree(node, childiter=tag_filter).by_attr(*args, **kwargs))