chore: Update GNG implementation and requirements

This commit is contained in:
2024-06-04 10:58:38 +02:00
parent cd8681d38e
commit 6d4c2ea601
2 changed files with 124 additions and 36 deletions

View File

@ -31,7 +31,7 @@ class GrowingNeuralGas:
self.edges.append(Edge(self.nodes[0], self.nodes[1]))
def fit(self, X, num_iterations=1000):
bar = IncrementalBar("Training", max=num_iterations)
bar = IncrementalBar("Training", max=num_iterations, suffix="%(percent)d%% - %(eta)ds/%(elapsed)ds")
for iteration in range(num_iterations):
# Step 1: Select a random input sample
x = X[np.random.randint(len(X))]

View File

@ -8,50 +8,138 @@ from sklearn import datasets as sk
dash.register_page(__name__, path="/gng", title="GNG", name="GNG")
# Generate synthetic data
X, _ = sk.make_moons(n_samples=200, noise=0.1)
X, Y = sk.make_moons(n_samples=200, noise=0.1, random_state=0)
clics = None
# Create and fit the GNG model
layout = html.Div([dbc.Button("Generate GNG", id="generate-gng"), dcc.Graph(id="gng-graph")])
layout = html.Div(
[
dbc.Row(
[
dbc.Col(
dbc.Button("Generate GNG", id="generate-gng"),
),
dbc.Col(
[
dbc.Label("Noise: "),
dbc.Input(
placeholder="Noise",
value=0.1,
type="number",
id="noise",
step=0.1,
),
]
),
dbc.Col(
[
dbc.Label("Iterations: "),
dbc.Input(
placeholder="Iterations",
value=1000,
type="number",
id="iterations",
),
]
),
dbc.Col(
[
dbc.Label("Max Nodes: "),
dbc.Input(
placeholder="Max Nodes",
value=200,
type="number",
id="nodes",
),
]
),
]
),
dcc.Graph(id="base-graph"),
dcc.Graph(id="gng-graph"),
]
)
@callback(Output("gng-graph", "figure"), [Input("generate-gng", "n_clicks")])
def generate_gng(n_clicks):
if n_clicks is None:
return go.Figure().update_layout(
showlegend=False,
margin=dict(l=0, r=0, t=0, b=0),
xaxis=dict(visible=False),
yaxis=dict(visible=False),
)
@callback(
[Output("gng-graph", "figure"), Output("base-graph", "figure")],
[
Input("generate-gng", "n_clicks"),
Input("noise", "value"),
Input("iterations", "value"),
Input("nodes", "value"),
],
)
def generate_gng(n_clicks, noise, iterations, nodes):
global clics
X, Y = sk.make_moons(n_samples=200, noise=noise, random_state=0)
gng = GrowingNeuralGas(input_dim=2)
gng.fit(X, num_iterations=20000)
fig = go.Figure()
for edge in gng.edges:
fig.add_trace(
go.Scatter(
x=[edge.nodes[0].position[0], edge.nodes[1].position[0]],
y=[edge.nodes[0].position[1], edge.nodes[1].position[1]],
mode="lines",
line=dict(width=2, color="white"),
)
print(X, Y)
fig2 = go.Figure()
fig2.add_trace(
go.Scatter(
x=X[Y == 0, 0], y=X[Y == 0, 1], mode="markers", marker=dict(color="blue")
)
for node in gng.nodes:
fig.add_trace(
go.Scatter(
x=[node.position[0]],
y=[node.position[1]],
mode="markers",
marker=dict(size=10, color="red"),
)
)
fig2.add_trace(
go.Scatter(
x=X[Y == 1, 0], y=X[Y == 1, 1], mode="markers", marker=dict(color="red")
)
fig.update_layout(
)
fig2.update_layout(
showlegend=False,
margin=dict(l=0, r=0, t=0, b=0),
xaxis=dict(visible=False),
yaxis=dict(visible=False),
title="Base Data",
)
return fig
if n_clicks != clics:
gng = GrowingNeuralGas(input_dim=2, max_nodes=nodes)
gng.fit(X, num_iterations=iterations)
fig = go.Figure()
for edge in gng.edges:
fig.add_trace(
go.Scatter(
x=[edge.nodes[0].position[0], edge.nodes[1].position[0]],
y=[edge.nodes[0].position[1], edge.nodes[1].position[1]],
mode="lines",
line=dict(width=2, color="white"),
)
)
for node in gng.nodes:
fig.add_trace(
go.Scatter(
x=[node.position[0]],
y=[node.position[1]],
mode="markers",
marker=dict(size=10, color="red"),
)
)
fig.update_layout(
showlegend=False,
margin=dict(l=0, r=0, t=0, b=0),
xaxis=dict(visible=False),
yaxis=dict(visible=False),
title="GNG Model",
)
clics = n_clicks
return [fig, fig2]
return [
go.Figure().update_layout(
showlegend=False,
margin=dict(l=0, r=0, t=0, b=0),
xaxis=dict(visible=False),
yaxis=dict(visible=False),
),
fig2,
]