diff --git a/3D_app/gng2.py b/3D_app/gng2.py index 875b01c..d2c4a33 100644 --- a/3D_app/gng2.py +++ b/3D_app/gng2.py @@ -1,5 +1,5 @@ import numpy as np -import plotly.graph_objects as plt +from progress.bar import IncrementalBar class Neuron: def __init__(self, position): @@ -31,6 +31,7 @@ class GrowingNeuralGas: self.edges.append(Edge(self.nodes[0], self.nodes[1])) def fit(self, X, num_iterations=1000): + bar = IncrementalBar("Training", max=num_iterations) for iteration in range(num_iterations): # Step 1: Select a random input sample x = X[np.random.randint(len(X))] @@ -78,6 +79,9 @@ class GrowingNeuralGas: # Step 10: Decrease all errors for node in self.nodes: node.error *= self.delta + + bar.next() + bar.finish() def find_edge(self, node1, node2): for edge in self.edges: diff --git a/3D_app/pages/gng.py b/3D_app/pages/gng.py index 060ca75..698ef74 100644 --- a/3D_app/pages/gng.py +++ b/3D_app/pages/gng.py @@ -1,6 +1,7 @@ import dash import plotly.graph_objects as go -from dash import html, dcc +from dash import html, dcc, callback, Input, Output +import dash_bootstrap_components as dbc from gng2 import GrowingNeuralGas from sklearn import datasets as sk @@ -10,34 +11,47 @@ dash.register_page(__name__, path="/gng", title="GNG", name="GNG") X, _ = sk.make_moons(n_samples=200, noise=0.1) # Create and fit the GNG model -gng = GrowingNeuralGas(input_dim=2) -gng.fit(X, num_iterations=2000) -fig = go.Figure() -for edge in gng.edges: - fig.add_trace( - go.Scatter( - x=[edge.nodes[0].position[0], edge.nodes[1].position[0]], - y=[edge.nodes[0].position[1], edge.nodes[1].position[1]], - mode="lines", - line=dict(width=2, color="black"), +layout = html.Div([dbc.Button("Generate GNG", id="generate-gng"), dcc.Graph(id="gng-graph")]) + +@callback(Output("gng-graph", "figure"), [Input("generate-gng", "n_clicks")]) +def generate_gng(n_clicks): + if n_clicks is None: + return go.Figure().update_layout( + showlegend=False, + margin=dict(l=0, r=0, t=0, b=0), + xaxis=dict(visible=False), + yaxis=dict(visible=False), ) - ) -for node in gng.nodes: - fig.add_trace( - go.Scatter( - x=[node.position[0]], - y=[node.position[1]], - mode="markers", - marker=dict(size=10, color="red"), + + gng = GrowingNeuralGas(input_dim=2) + gng.fit(X, num_iterations=20000) + + fig = go.Figure() + for edge in gng.edges: + fig.add_trace( + go.Scatter( + x=[edge.nodes[0].position[0], edge.nodes[1].position[0]], + y=[edge.nodes[0].position[1], edge.nodes[1].position[1]], + mode="lines", + line=dict(width=2, color="white"), + ) ) + for node in gng.nodes: + fig.add_trace( + go.Scatter( + x=[node.position[0]], + y=[node.position[1]], + mode="markers", + marker=dict(size=10, color="red"), + ) + ) + + fig.update_layout( + showlegend=False, + margin=dict(l=0, r=0, t=0, b=0), + xaxis=dict(visible=False), + yaxis=dict(visible=False), ) -fig.update_layout( - showlegend=False, - margin=dict(l=0, r=0, t=0, b=0), - xaxis=dict(visible=False), - yaxis=dict(visible=False), -) - -layout = html.Div(dcc.Graph(figure=fig)) + return fig