chore: Update GNG implementation and requirements

feat: Update GNG implementation in gng2.py and gng.py to include a progress bar during training. Also, update requirements.txt.
This commit is contained in:
2024-06-03 10:21:45 +02:00
parent 83fb326cb3
commit cd8681d38e
2 changed files with 46 additions and 28 deletions

View File

@ -1,5 +1,5 @@
import numpy as np
import plotly.graph_objects as plt
from progress.bar import IncrementalBar
class Neuron:
def __init__(self, position):
@ -31,6 +31,7 @@ class GrowingNeuralGas:
self.edges.append(Edge(self.nodes[0], self.nodes[1]))
def fit(self, X, num_iterations=1000):
bar = IncrementalBar("Training", max=num_iterations)
for iteration in range(num_iterations):
# Step 1: Select a random input sample
x = X[np.random.randint(len(X))]
@ -78,6 +79,9 @@ class GrowingNeuralGas:
# Step 10: Decrease all errors
for node in self.nodes:
node.error *= self.delta
bar.next()
bar.finish()
def find_edge(self, node1, node2):
for edge in self.edges:

View File

@ -1,6 +1,7 @@
import dash
import plotly.graph_objects as go
from dash import html, dcc
from dash import html, dcc, callback, Input, Output
import dash_bootstrap_components as dbc
from gng2 import GrowingNeuralGas
from sklearn import datasets as sk
@ -10,34 +11,47 @@ dash.register_page(__name__, path="/gng", title="GNG", name="GNG")
X, _ = sk.make_moons(n_samples=200, noise=0.1)
# Create and fit the GNG model
gng = GrowingNeuralGas(input_dim=2)
gng.fit(X, num_iterations=2000)
fig = go.Figure()
for edge in gng.edges:
fig.add_trace(
go.Scatter(
x=[edge.nodes[0].position[0], edge.nodes[1].position[0]],
y=[edge.nodes[0].position[1], edge.nodes[1].position[1]],
mode="lines",
line=dict(width=2, color="black"),
layout = html.Div([dbc.Button("Generate GNG", id="generate-gng"), dcc.Graph(id="gng-graph")])
@callback(Output("gng-graph", "figure"), [Input("generate-gng", "n_clicks")])
def generate_gng(n_clicks):
if n_clicks is None:
return go.Figure().update_layout(
showlegend=False,
margin=dict(l=0, r=0, t=0, b=0),
xaxis=dict(visible=False),
yaxis=dict(visible=False),
)
)
for node in gng.nodes:
fig.add_trace(
go.Scatter(
x=[node.position[0]],
y=[node.position[1]],
mode="markers",
marker=dict(size=10, color="red"),
gng = GrowingNeuralGas(input_dim=2)
gng.fit(X, num_iterations=20000)
fig = go.Figure()
for edge in gng.edges:
fig.add_trace(
go.Scatter(
x=[edge.nodes[0].position[0], edge.nodes[1].position[0]],
y=[edge.nodes[0].position[1], edge.nodes[1].position[1]],
mode="lines",
line=dict(width=2, color="white"),
)
)
for node in gng.nodes:
fig.add_trace(
go.Scatter(
x=[node.position[0]],
y=[node.position[1]],
mode="markers",
marker=dict(size=10, color="red"),
)
)
fig.update_layout(
showlegend=False,
margin=dict(l=0, r=0, t=0, b=0),
xaxis=dict(visible=False),
yaxis=dict(visible=False),
)
fig.update_layout(
showlegend=False,
margin=dict(l=0, r=0, t=0, b=0),
xaxis=dict(visible=False),
yaxis=dict(visible=False),
)
layout = html.Div(dcc.Graph(figure=fig))
return fig