From 6d4c2ea60149fa8e701f1036417524b5a4523b0b Mon Sep 17 00:00:00 2001 From: Le Stagiaire Date: Tue, 4 Jun 2024 10:58:38 +0200 Subject: [PATCH] chore: Update GNG implementation and requirements --- 3D_app/gng2.py | 2 +- 3D_app/pages/gng.py | 158 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 124 insertions(+), 36 deletions(-) diff --git a/3D_app/gng2.py b/3D_app/gng2.py index d2c4a33..aadefe2 100644 --- a/3D_app/gng2.py +++ b/3D_app/gng2.py @@ -31,7 +31,7 @@ class GrowingNeuralGas: self.edges.append(Edge(self.nodes[0], self.nodes[1])) def fit(self, X, num_iterations=1000): - bar = IncrementalBar("Training", max=num_iterations) + bar = IncrementalBar("Training", max=num_iterations, suffix="%(percent)d%% - %(eta)ds/%(elapsed)ds") for iteration in range(num_iterations): # Step 1: Select a random input sample x = X[np.random.randint(len(X))] diff --git a/3D_app/pages/gng.py b/3D_app/pages/gng.py index 698ef74..4563179 100644 --- a/3D_app/pages/gng.py +++ b/3D_app/pages/gng.py @@ -8,50 +8,138 @@ from sklearn import datasets as sk dash.register_page(__name__, path="/gng", title="GNG", name="GNG") # Generate synthetic data -X, _ = sk.make_moons(n_samples=200, noise=0.1) +X, Y = sk.make_moons(n_samples=200, noise=0.1, random_state=0) + + +clics = None # Create and fit the GNG model -layout = html.Div([dbc.Button("Generate GNG", id="generate-gng"), dcc.Graph(id="gng-graph")]) +layout = html.Div( + [ + dbc.Row( + [ + dbc.Col( + dbc.Button("Generate GNG", id="generate-gng"), + ), + dbc.Col( + [ + dbc.Label("Noise: "), + dbc.Input( + placeholder="Noise", + value=0.1, + type="number", + id="noise", + step=0.1, + ), + ] + ), + dbc.Col( + [ + dbc.Label("Iterations: "), + dbc.Input( + placeholder="Iterations", + value=1000, + type="number", + id="iterations", + ), + ] + ), + dbc.Col( + [ + dbc.Label("Max Nodes: "), + dbc.Input( + placeholder="Max Nodes", + value=200, + type="number", + id="nodes", + ), + ] + ), + ] + ), + dcc.Graph(id="base-graph"), + dcc.Graph(id="gng-graph"), + ] +) -@callback(Output("gng-graph", "figure"), [Input("generate-gng", "n_clicks")]) -def generate_gng(n_clicks): - if n_clicks is None: - return go.Figure().update_layout( - showlegend=False, - margin=dict(l=0, r=0, t=0, b=0), - xaxis=dict(visible=False), - yaxis=dict(visible=False), - ) + +@callback( + [Output("gng-graph", "figure"), Output("base-graph", "figure")], + [ + Input("generate-gng", "n_clicks"), + Input("noise", "value"), + Input("iterations", "value"), + Input("nodes", "value"), + ], +) +def generate_gng(n_clicks, noise, iterations, nodes): + global clics + X, Y = sk.make_moons(n_samples=200, noise=noise, random_state=0) - gng = GrowingNeuralGas(input_dim=2) - gng.fit(X, num_iterations=20000) - - fig = go.Figure() - for edge in gng.edges: - fig.add_trace( - go.Scatter( - x=[edge.nodes[0].position[0], edge.nodes[1].position[0]], - y=[edge.nodes[0].position[1], edge.nodes[1].position[1]], - mode="lines", - line=dict(width=2, color="white"), - ) + print(X, Y) + + fig2 = go.Figure() + fig2.add_trace( + go.Scatter( + x=X[Y == 0, 0], y=X[Y == 0, 1], mode="markers", marker=dict(color="blue") ) - for node in gng.nodes: - fig.add_trace( - go.Scatter( - x=[node.position[0]], - y=[node.position[1]], - mode="markers", - marker=dict(size=10, color="red"), - ) + ) + fig2.add_trace( + go.Scatter( + x=X[Y == 1, 0], y=X[Y == 1, 1], mode="markers", marker=dict(color="red") ) - - fig.update_layout( + ) + fig2.update_layout( showlegend=False, margin=dict(l=0, r=0, t=0, b=0), xaxis=dict(visible=False), yaxis=dict(visible=False), + title="Base Data", ) - - return fig + + if n_clicks != clics: + gng = GrowingNeuralGas(input_dim=2, max_nodes=nodes) + gng.fit(X, num_iterations=iterations) + + fig = go.Figure() + for edge in gng.edges: + fig.add_trace( + go.Scatter( + x=[edge.nodes[0].position[0], edge.nodes[1].position[0]], + y=[edge.nodes[0].position[1], edge.nodes[1].position[1]], + mode="lines", + line=dict(width=2, color="white"), + ) + ) + for node in gng.nodes: + fig.add_trace( + go.Scatter( + x=[node.position[0]], + y=[node.position[1]], + mode="markers", + marker=dict(size=10, color="red"), + ) + ) + + fig.update_layout( + showlegend=False, + margin=dict(l=0, r=0, t=0, b=0), + xaxis=dict(visible=False), + yaxis=dict(visible=False), + title="GNG Model", + ) + + clics = n_clicks + + return [fig, fig2] + + return [ + go.Figure().update_layout( + showlegend=False, + margin=dict(l=0, r=0, t=0, b=0), + xaxis=dict(visible=False), + yaxis=dict(visible=False), + ), + fig2, + ]