Skip to content

Topic Modeling Visualization

A class for visualizing topics and their associated documents on a 2D density map.

This visualizer plots documents and topics on a 2D space with an option to show text labels, contour density representations, and topic centroids. The visualization is useful for understanding the distribution and clustering of topics in a document corpus.

Source code in bunkatopics/visualization/topic_visualizer.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
class TopicVisualizer:
    """
    A class for visualizing topics and their associated documents on a 2D density map.

    The visualizer draws documents and topic centroids in a 2D space, optionally with a
    contour density background, per-document hover text, metadata-based coloring and
    convex hulls around topics.
    """

    def __init__(
        self,
        show_text: bool = False,
        width: int = 1000,
        height: int = 1000,
        label_size_ratio: int = 100,
        point_size_ratio: int = 150,
        colorscale: str = "delta",
        density: bool = False,
        convex_hull: bool = False,
    ) -> None:
        """
        Initializes the TopicVisualizer with specified parameters.

        Args:
            show_text (bool): If True, documents are drawn as points whose wrapped text is
                shown on hover. Defaults to False.
            width (int): The width of the plot in pixels. Margins and font sizes are also
                derived from it. Defaults to 1000.
            height (int): The height of the plot in pixels. Defaults to 1000.
            label_size_ratio (int): Divisor applied to `width` to get the topic label font
                size. Defaults to 100.
            point_size_ratio (int): Divisor applied to `width` to get the document marker
                size. Defaults to 150.
            colorscale (str): The color scale for contour density representation. Defaults to "delta".
            density (bool): Whether to display a density map. Ignored by `fit_transform`
                when a `color` field is given.
            convex_hull (bool): Whether to display lines around the clusters.
        """
        self.show_text = show_text
        self.width = width
        self.height = height
        self.label_size_ratio = label_size_ratio
        self.point_size_ratio = point_size_ratio
        self.colorscale = colorscale
        self.density = density
        self.convex_hull = convex_hull

        # Named Plotly color scales; not read anywhere in this class.
        self.colorscale_list = [
            "Greys",
            "YlGnBu",
            "Greens",
            "YlOrRd",
            "Bluered",
            "RdBu",
            "Reds",
            "Blues",
            "Picnic",
            "Rainbow",
            "Portland",
            "Jet",
            "Hot",
            "Blackbody",
            "Earth",
            "Electric",
            "Viridis",
            "Cividis",
            "Inferno",
            "Magma",
            "Plasma",
        ]

    def fit_transform(
        self,
        docs: t.List[Document],
        topics: t.List[Topic],
        color: t.Optional[str] = None,
    ) -> go.Figure:
        """
        Generates a Plotly figure visualizing the given documents and topics.

        Documents are optionally drawn as hoverable points (`show_text`), topic centroids
        are annotated with their names, and a contour density background and convex hulls
        are added when enabled. The instance configuration is never modified, so the same
        visualizer can be reused across calls.

        Args:
            docs (List[Document]): A list of Document objects to be visualized.
            topics (List[Topic]): A list of Topic objects for clustering visualization.
            color (str): The metadata field to use for coloring the documents. When given,
                the density background is not drawn for this call. Defaults to None.

        Returns:
            go.Figure: A Plotly figure object representing the visualized documents and topics.

        Raises:
            KeyError: If `color` is given and a document's metadata lacks that field.
        """

        docs_x = [doc.x for doc in docs]
        docs_y = [doc.y for doc in docs]
        docs_topic_id = [doc.topic_id for doc in docs]
        # wrap_by_word presumably inserts line breaks every N words -- confirm in helper.
        docs_content_plotly = [wrap_by_word(doc.content, 10) for doc in docs]

        topics_x = [topic.x_centroid for topic in topics]
        topics_y = [topic.y_centroid for topic in topics]
        topics_name_plotly = [wrap_by_word(topic.name, 6) for topic in topics]

        # Decide locally instead of overwriting self.density, so that a colored call
        # does not disable the density map for later calls on the same instance.
        show_density = bool(self.density) and color is None

        if show_density:
            # Create a figure with Histogram2dContour
            fig_density = go.Figure(
                go.Histogram2dContour(
                    x=docs_x,
                    y=docs_y,
                    colorscale=self.colorscale,
                    showscale=False,
                    hoverinfo="none",
                )
            )
            fig_density.update_traces(
                contours_coloring="fill", contours_showlabels=False
            )
        else:
            fig_density = go.Figure()

        # Update layout settings (all margins are derived from the width)
        fig_density.update_layout(
            font_size=25,
            width=self.width,
            height=self.height,
            margin=dict(
                t=self.width / 50,
                b=self.width / 50,
                r=self.width / 50,
                l=self.width / 50,
            ),
            title=dict(font=dict(size=self.width / 40)),
        )

        # Hover payload per document: slot 0 = topic id, slot 1 = wrapped content,
        # slot 2 = value of the color field (left unset when no color is requested).
        nk = np.empty(shape=(len(docs), 3, 1), dtype="object")
        nk[:, 0] = np.array(docs_topic_id).reshape(-1, 1)
        nk[:, 1] = np.array(docs_content_plotly).reshape(-1, 1)

        colormap = None  # category -> color; only set for categorical coloring
        list_color_figure = None
        colorscale = None
        colorbar = None

        if color is None:
            hovertemplate = "<br>%{customdata[1]}<br>"
        else:
            list_color = [doc.metadata[color] for doc in docs]
            nk[:, 2] = np.array(list_color).reshape(-1, 1)
            hovertemplate = f"<br>%{{customdata[1]}}<br>{color}: %{{customdata[2]}}"

            if check_list_type(list_color) == "string":
                # Categorical field: one palette color per category, cycling through
                # the palette when there are more categories than colors. Categories
                # are taken in first-occurrence order so the mapping (and the legend)
                # is reproducible across runs.
                palette = px.colors.qualitative.Dark24
                colormap = {
                    category: palette[i % len(palette)]
                    for i, category in enumerate(dict.fromkeys(list_color))
                }
                list_color_figure = [colormap[value] for value in list_color]
            else:
                # Non-string field: hand the raw values to Plotly with a color bar.
                list_color_figure = list_color
                colorscale = "RdBu"
                colorbar = dict(title=color)

        if self.show_text:
            # Add points with information
            fig_density.add_trace(
                go.Scatter(
                    x=docs_x,
                    y=docs_y,
                    mode="markers",
                    marker=dict(
                        color=list_color_figure,
                        size=self.width / self.point_size_ratio,
                        opacity=0.5,
                        colorscale=colorscale,
                        colorbar=colorbar,
                    ),
                    showlegend=False,
                    customdata=nk,
                    hovertemplate=hovertemplate,
                ),
            )

        if colormap is not None:
            # The document trace hides its legend entry, so add one point-less
            # trace per category purely to populate the legend.
            for category, color_item in colormap.items():
                fig_density.add_trace(
                    go.Scatter(
                        x=[None],
                        y=[None],
                        mode="markers",
                        marker=dict(color=color_item),
                        name=category,
                    )
                )

        # Add centroids labels
        for x, y, label in zip(topics_x, topics_y, topics_name_plotly):
            fig_density.add_annotation(
                x=x,
                y=y,
                text=label,
                showarrow=True,
                arrowhead=1,
                font=dict(
                    family="Courier New, monospace",
                    size=self.width / self.label_size_ratio,
                    color="blue",
                ),
                bordercolor="#c7c7c7",
                borderwidth=self.width / 1000,
                borderpad=self.width / 500,
                bgcolor="white",
                opacity=1,
                arrowcolor="#ff7f0e",
            )

        if self.convex_hull:
            for topic in topics:
                hull = getattr(topic, "convex_hull", None)
                if hull is None:
                    # Nothing to outline for this topic.
                    continue
                # Best effort per topic: one malformed hull must not prevent the
                # remaining topics from being outlined.
                try:
                    fig_density.add_trace(
                        go.Scatter(
                            x=hull.x_coordinates,
                            y=hull.y_coordinates,
                            mode="lines",
                            name="Convex Hull",
                            line=dict(color="grey", dash="dot"),
                            hoverinfo="none",
                            showlegend=False,
                        )
                    )
                except Exception as e:
                    print(e)

        if color is not None:
            fig_density.update_layout(
                legend_title_text=color,
                legend=dict(
                    font=dict(
                        family="Arial",
                        size=int(self.width / 60),  # Adjust font size of the legend
                        color="black",
                    ),
                ),
            )

            fig_density.update_layout(plot_bgcolor="white")

        fig_density.update_xaxes(showgrid=False, showticklabels=False, zeroline=False)
        fig_density.update_yaxes(showgrid=False, showticklabels=False, zeroline=False)

        return fig_density

__init__(show_text=False, width=1000, height=1000, label_size_ratio=100, point_size_ratio=150, colorscale='delta', density=False, convex_hull=False)

Initializes the TopicVisualizer with specified parameters.

Parameters:

Name Type Description Default
show_text bool

If True, documents are drawn as points whose wrapped text is shown on hover. Defaults to False.

False
width int

The width of the plot in pixels. Defaults to 1000.

1000
height int

The height of the plot in pixels. Defaults to 1000.

1000
label_size_ratio int

Divisor applied to the plot width to compute the topic label font size. Defaults to 100.

100
colorscale str

The color scale for contour density representation. Defaults to "delta".

'delta'
density bool

Whether to display a density map. Ignored when a color field is passed to fit_transform.

False
convex_hull bool

Whether to display lines around the clusters

False
Source code in bunkatopics/visualization/topic_visualizer.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(
    self,
    show_text: bool = False,
    width: int = 1000,
    height: int = 1000,
    label_size_ratio: int = 100,
    point_size_ratio: int = 150,
    colorscale: str = "delta",
    density: bool = False,
    convex_hull: bool = False,
) -> None:
    """
    Initializes the TopicVisualizer with specified parameters.

    Args:
        show_text (bool): If True, documents are drawn as points whose wrapped text is
            shown on hover. Defaults to False.
        width (int): The width of the plot in pixels. Defaults to 1000.
        height (int): The height of the plot in pixels. Defaults to 1000.
        label_size_ratio (int): Divisor applied to `width` to get the topic label font
            size. Defaults to 100.
        point_size_ratio (int): Divisor applied to `width` to get the document marker
            size. Defaults to 150.
        colorscale (str): The color scale for contour density representation. Defaults to "delta".
        density (bool): Whether to display a density map.
        convex_hull (bool): Whether to display lines around the clusters.
    """
    self.show_text = show_text
    self.width = width
    self.height = height
    self.label_size_ratio = label_size_ratio
    self.point_size_ratio = point_size_ratio
    self.colorscale = colorscale
    self.density = density
    self.convex_hull = convex_hull

    # Named Plotly color scales kept on the instance for reference.
    self.colorscale_list = [
        "Greys",
        "YlGnBu",
        "Greens",
        "YlOrRd",
        "Bluered",
        "RdBu",
        "Reds",
        "Blues",
        "Picnic",
        "Rainbow",
        "Portland",
        "Jet",
        "Hot",
        "Blackbody",
        "Earth",
        "Electric",
        "Viridis",
        "Cividis",
        "Inferno",
        "Magma",
        "Plasma",
    ]

fit_transform(docs, topics, color=None)

Generates a Plotly figure visualizing the given documents and topics.

This method processes the documents and topics to create a 2D scatter plot, showing the distribution and clustering of topics. It supports displaying text labels, contour density, and centroids for topics.

Parameters:

Name Type Description Default
docs List[Document]

A list of Document objects to be visualized.

required
topics List[Topic]

A list of Topic objects for clustering visualization.

required
color str

The metadata field to use for coloring the documents. Defaults to None.

None

Returns:

Type Description
Figure

go.Figure: A Plotly figure object representing the visualized documents and topics.

Source code in bunkatopics/visualization/topic_visualizer.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
def fit_transform(
    self,
    docs: t.List[Document],
    topics: t.List[Topic],
    color: t.Optional[str] = None,
) -> go.Figure:
    """
    Generates a Plotly figure visualizing the given documents and topics.

    Documents are optionally drawn as hoverable points (`show_text`), topic centroids
    are annotated with their names, and a contour density background and convex hulls
    are added when enabled. The instance configuration is never modified, so the same
    visualizer can be reused across calls.

    Args:
        docs (List[Document]): A list of Document objects to be visualized.
        topics (List[Topic]): A list of Topic objects for clustering visualization.
        color (str): The metadata field to use for coloring the documents. When given,
            the density background is not drawn for this call. Defaults to None.

    Returns:
        go.Figure: A Plotly figure object representing the visualized documents and topics.

    Raises:
        KeyError: If `color` is given and a document's metadata lacks that field.
    """

    docs_x = [doc.x for doc in docs]
    docs_y = [doc.y for doc in docs]
    docs_topic_id = [doc.topic_id for doc in docs]
    # wrap_by_word presumably inserts line breaks every N words -- confirm in helper.
    docs_content_plotly = [wrap_by_word(doc.content, 10) for doc in docs]

    topics_x = [topic.x_centroid for topic in topics]
    topics_y = [topic.y_centroid for topic in topics]
    topics_name_plotly = [wrap_by_word(topic.name, 6) for topic in topics]

    # Decide locally instead of overwriting self.density, so that a colored call
    # does not disable the density map for later calls on the same instance.
    show_density = bool(self.density) and color is None

    if show_density:
        # Create a figure with Histogram2dContour
        fig_density = go.Figure(
            go.Histogram2dContour(
                x=docs_x,
                y=docs_y,
                colorscale=self.colorscale,
                showscale=False,
                hoverinfo="none",
            )
        )
        fig_density.update_traces(
            contours_coloring="fill", contours_showlabels=False
        )
    else:
        fig_density = go.Figure()

    # Update layout settings (all margins are derived from the width)
    fig_density.update_layout(
        font_size=25,
        width=self.width,
        height=self.height,
        margin=dict(
            t=self.width / 50,
            b=self.width / 50,
            r=self.width / 50,
            l=self.width / 50,
        ),
        title=dict(font=dict(size=self.width / 40)),
    )

    # Hover payload per document: slot 0 = topic id, slot 1 = wrapped content,
    # slot 2 = value of the color field (left unset when no color is requested).
    nk = np.empty(shape=(len(docs), 3, 1), dtype="object")
    nk[:, 0] = np.array(docs_topic_id).reshape(-1, 1)
    nk[:, 1] = np.array(docs_content_plotly).reshape(-1, 1)

    colormap = None  # category -> color; only set for categorical coloring
    list_color_figure = None
    colorscale = None
    colorbar = None

    if color is None:
        hovertemplate = "<br>%{customdata[1]}<br>"
    else:
        list_color = [doc.metadata[color] for doc in docs]
        nk[:, 2] = np.array(list_color).reshape(-1, 1)
        hovertemplate = f"<br>%{{customdata[1]}}<br>{color}: %{{customdata[2]}}"

        if check_list_type(list_color) == "string":
            # Categorical field: one palette color per category, cycling through
            # the palette when there are more categories than colors. Categories
            # are taken in first-occurrence order so the mapping (and the legend)
            # is reproducible across runs.
            palette = px.colors.qualitative.Dark24
            colormap = {
                category: palette[i % len(palette)]
                for i, category in enumerate(dict.fromkeys(list_color))
            }
            list_color_figure = [colormap[value] for value in list_color]
        else:
            # Non-string field: hand the raw values to Plotly with a color bar.
            list_color_figure = list_color
            colorscale = "RdBu"
            colorbar = dict(title=color)

    if self.show_text:
        # Add points with information
        fig_density.add_trace(
            go.Scatter(
                x=docs_x,
                y=docs_y,
                mode="markers",
                marker=dict(
                    color=list_color_figure,
                    size=self.width / self.point_size_ratio,
                    opacity=0.5,
                    colorscale=colorscale,
                    colorbar=colorbar,
                ),
                showlegend=False,
                customdata=nk,
                hovertemplate=hovertemplate,
            ),
        )

    if colormap is not None:
        # The document trace hides its legend entry, so add one point-less
        # trace per category purely to populate the legend.
        for category, color_item in colormap.items():
            fig_density.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode="markers",
                    marker=dict(color=color_item),
                    name=category,
                )
            )

    # Add centroids labels
    for x, y, label in zip(topics_x, topics_y, topics_name_plotly):
        fig_density.add_annotation(
            x=x,
            y=y,
            text=label,
            showarrow=True,
            arrowhead=1,
            font=dict(
                family="Courier New, monospace",
                size=self.width / self.label_size_ratio,
                color="blue",
            ),
            bordercolor="#c7c7c7",
            borderwidth=self.width / 1000,
            borderpad=self.width / 500,
            bgcolor="white",
            opacity=1,
            arrowcolor="#ff7f0e",
        )

    if self.convex_hull:
        for topic in topics:
            hull = getattr(topic, "convex_hull", None)
            if hull is None:
                # Nothing to outline for this topic.
                continue
            # Best effort per topic: one malformed hull must not prevent the
            # remaining topics from being outlined.
            try:
                fig_density.add_trace(
                    go.Scatter(
                        x=hull.x_coordinates,
                        y=hull.y_coordinates,
                        mode="lines",
                        name="Convex Hull",
                        line=dict(color="grey", dash="dot"),
                        hoverinfo="none",
                        showlegend=False,
                    )
                )
            except Exception as e:
                print(e)

    if color is not None:
        fig_density.update_layout(
            legend_title_text=color,
            legend=dict(
                font=dict(
                    family="Arial",
                    size=int(self.width / 60),  # Adjust font size of the legend
                    color="black",
                ),
            ),
        )

        fig_density.update_layout(plot_bgcolor="white")

    fig_density.update_xaxes(showgrid=False, showticklabels=False, zeroline=False)
    fig_density.update_yaxes(showgrid=False, showticklabels=False, zeroline=False)

    return fig_density