Source code for beam_search.graph_setup

from typing import Iterable, Dict
from graphviz import Digraph


class GraphGenerator:
    """Builds a basic graph with graphviz.

    Node ids are consecutive integers stored as strings; the ``START`` node
    created by the constructor has id ``"0"``. ``_idx`` is the id of the most
    recently created node and ``_parent`` is the id that newly created nodes
    are attached to.
    """

    def __init__(self):
        """Create the graph and add the ``START`` node with id ``"0"``."""
        self._idx = "0"
        self._parent = "0"
        self.graph = Digraph(strict=True)
        self.graph.node(
            self._idx,
            "START",
            fillcolor="darkolivegreen3",
            style="filled",
            fontsize="50",
        )

    def _inc_idx(self, inc: int = 1):
        """Advance the current node id by ``inc`` (ids are kept as strings)."""
        self._idx = str(int(self._idx) + inc)

    def _set_last_chosen(self, new_id):
        """Make ``new_id`` the parent that subsequent nodes attach to."""
        self._parent = str(new_id)

    def complete_conversation(self, final_val):
        """Add the terminal ``GOAL REACHED`` node and make it the parent.

        Args:
            final_val: Value shown under the ``GOAL REACHED`` label.
        """
        self.create_from_parent(
            {"GOAL REACHED": final_val}, "darkolivegreen3", "GOAL REACHED"
        )

    def create_from_parent(
        self, nodes: Dict[str, float], fillcolor: str, new_parent: str = None
    ):
        """Attach one new node per entry of ``nodes`` to the current parent.

        Args:
            nodes: Maps each node name to the value printed under it.
            fillcolor: Fill colour applied to every node created here.
            new_parent: Name of the node in ``nodes`` that becomes the parent
                for later calls; its incoming edge is highlighted. A falsy
                value leaves the current parent unchanged.

        Raises:
            ValueError: If ``new_parent`` is given but is not a key of
                ``nodes``. Raised before the graph is touched.
        """
        # Validate first: previously a missing name only failed after every
        # node and edge had been added, with an opaque UnboundLocalError.
        if new_parent and new_parent not in nodes:
            raise ValueError(
                f"new_parent {new_parent!r} is not one of the nodes being created"
            )

        new_parent_id = None
        for node, conf in nodes.items():
            self._inc_idx()
            self.graph.node(
                self._idx,
                f"{node}\n{conf}",
                fillcolor=fillcolor,
                style="filled",
                fontsize="50",
            )
            if new_parent and node == new_parent:
                new_parent_id = self._idx
                edge_color, penwidth = "forestgreen", "10.0"
            else:
                edge_color, penwidth = "grey45", "5.0"
            self.graph.edge(
                self._parent, self._idx, color=edge_color, penwidth=penwidth
            )

        # Switch parents only after all siblings are attached to the old one.
        if new_parent:
            self._set_last_chosen(new_parent_id)
class BeamSearchGraphGenerator(GraphGenerator):
    """Handles building a graph for beam search.

    Args:
        k (int): The k value for the beam search.
    """

    class GraphBeam:
        """Inner class that holds the id maps for each beam so that any node
        can be easily referenced.

        ``parent_nodes_id_map`` maps a node name to the list of graph ids
        created under that name; the last entry is the most recent one.
        """

        def __init__(self, parent_nodes_id_map: Dict = None):
            """Copy ``parent_nodes_id_map``, defaulting to just ``START``.

            Args:
                parent_nodes_id_map: Name -> list of ids. A falsy value
                    (``None`` or empty) yields ``{"START": ["0"]}``.
            """
            if not parent_nodes_id_map:
                parent_nodes_id_map = {"START": ["0"]}
            # Copy each list so this beam does not share them with the caller.
            self.parent_nodes_id_map = {
                name: list(ids) for name, ids in parent_nodes_id_map.items()
            }

    def __init__(self, k: int):
        """Create the base graph and ``k`` independent beams."""
        super().__init__()
        self.beams = [self.GraphBeam() for _ in range(k)]

    def set_last_chosen(self, node: str, beam: int):
        """Make the most recent id recorded for ``node`` in ``beam`` the parent.

        Raises:
            KeyError: If ``node`` has no id recorded in that beam.
            IndexError: If ``beam`` is out of range.
        """
        self._set_last_chosen(self.beams[beam].parent_nodes_id_map[node][-1])

    def create_nodes_highlight_k(
        self,
        nodes: Dict[str, float],
        fillcolor: str,
        parent: str,
        beam: int,
        k_chosen: Iterable[str],
    ):
        """Attach ``nodes`` to ``parent`` in ``beam``, highlighting the chosen k.

        Every entry of ``nodes`` becomes a graph node connected to the most
        recent id recorded for ``parent``. Nodes named in ``k_chosen`` get a
        highlighted edge and have their new id recorded in the beam's id map.

        Args:
            nodes: Maps each node name to the value printed under it.
            fillcolor: Fill colour applied to every node created here.
            parent: Name of the parent node, resolved through the beam's map.
            beam: Index of the beam whose id map is read and updated.
            k_chosen: Names of the nodes selected by the beam search. Any
                iterable is accepted, including one-shot iterators; a bare
                string is treated as a single name.

        Raises:
            KeyError: If ``parent`` has no id recorded in that beam.
            IndexError: If ``beam`` is out of range.
        """
        # Materialise once: an ``in`` test on a generator consumes it, so later
        # chosen nodes would be missed, and on a str it is a substring test.
        if isinstance(k_chosen, str):
            chosen = {k_chosen}
        else:
            chosen = set(k_chosen)

        id_map = self.beams[beam].parent_nodes_id_map
        # have to access the parent ID before potentially making changes to the
        # map to prevent overwriting in the case where you have a node "A"
        # connected to a parent "A" (otherwise you would attach the node to
        # itself)
        parent_id = id_map[parent][-1]

        for node, conf in nodes.items():
            self._inc_idx()
            self.graph.node(
                self._idx,
                f"{node}\n{conf}",
                fillcolor=fillcolor,
                style="filled",
                fontsize="50",
            )
            if node in chosen:
                edge_color, arrowhead, penwidth = "purple", "normal", "10.0"
                # create the list if it doesn't exist yet, otherwise add to it
                id_map.setdefault(node, []).append(self._idx)
            else:
                edge_color, arrowhead, penwidth = "grey45", "none", "5.0"
            self.graph.edge(
                parent_id,
                self._idx,
                color=edge_color,
                penwidth=penwidth,
                arrowhead=arrowhead,
            )