From 4ae3e9f251d7080b152974f66f7b7585a897e577 Mon Sep 17 00:00:00 2001 From: Le Stagiaire Date: Wed, 12 Jun 2024 11:26:47 +0200 Subject: [PATCH] feat: Add config.json for easier application configuration chore: Update GNG implementation and requirements feat: Update GNG implementation in gng2.py and gng.py to include a progress bar during training. Also, update requirements.txt. feat: Started implementation of GNG feat: Sampling function in settings popup --- 3D_app/README.md | 10 ++- 3D_app/config.json | 5 ++ 3D_app/gng2.py | 157 +++++++++++++++++++++++++++++++++--------- 3D_app/main.py | 61 +++++++++++++--- 3D_app/pages/ascan.py | 121 +++++++++++++++++++++----------- 3D_app/pages/gng.py | 44 +++++++++--- 3D_app/pages/home.py | 68 +++++++++--------- 7 files changed, 341 insertions(+), 125 deletions(-) create mode 100644 3D_app/config.json diff --git a/3D_app/README.md b/3D_app/README.md index 004460a..143835b 100644 --- a/3D_app/README.md +++ b/3D_app/README.md @@ -1,7 +1,8 @@ # Installation 1. Install Python requirements: `pip install -r requirements.txt` -2. Launch program: `py main.py` +2. Edit config as you want in `config.json` +3. Launch program: `py main.py` # Changelog @@ -21,3 +22,10 @@ * Ajout d'un menu pour naviguer entre les pages * Ajout des popups pour afficher les graphiques en plein écran * Ajout de la page avec les filtres pour la A-scan + +## V4 + +* Ajout de la possibilité d'ouvrir et sauvegarder des datasets +* Ajout des GNG +* Ajout de config.json pour éditer la configuration de l'application plus facilement +* Ajustement des callback pour éviter les bugs diff --git a/3D_app/config.json b/3D_app/config.json new file mode 100644 index 0000000..94624fc --- /dev/null +++ b/3D_app/config.json @@ -0,0 +1,5 @@ +{ + "port": 8051, + "debug": true, + "app_title": "3D App" +} \ No newline at end of file diff --git a/3D_app/gng2.py b/3D_app/gng2.py index aadefe2..15e59ba 100644 --- a/3D_app/gng2.py +++ b/3D_app/gng2.py @@ -1,18 +1,82 @@ import numpy as np -from progress.bar import IncrementalBar +from progress.bar import ShadyBar + class Neuron: def __init__(self, position): self.position = position self.error = 0.0 + class Edge: def __init__(self, node1, node2): self.nodes = (node1, node2) self.age = 0 + class GrowingNeuralGas: - def __init__(self, input_dim, max_nodes=100, max_age=100, epsilon_b=0.05, epsilon_n=0.006, alpha=0.5, delta=0.995, lambda_=100): + """Growing Neural Gas class""" + + def __init__( + self, + input_dim, + max_nodes=100, + max_age=100, + epsilon_b=0.05, + epsilon_n=0.006, + alpha=0.5, + delta=0.995, + lambda_=100, + ): + """ + Create a new Growing Neural Gas model + + Args: + ---------- + ``input_dim (int)``: + Number of input dimensions + ``max_nodes (int, optional)``: + Number of maximum nodes. Defaults to 100. + ``max_age (int, optional)``: + Maximum age of nodes. Defaults to 100. + ``epsilon_b (float, optional)``: + ??. Defaults to 0.05. + ``epsilon_n (float, optional)``: + ??. Defaults to 0.006. + ``alpha (float, optional)``: + ??. Defaults to 0.5. + ``delta (float, optional)``: + ??. Defaults to 0.995. + ``lambda_ (int, optional)``: + ??. Defaults to 100. + + Returns: + ---------- + ``Nothing`` + + Attributes: + ---------- + ``input_dim (int)``: + Number of input dimensions + ``max_nodes (int)``: + Number of maximum nodes + ``max_age (int)``: + Maximum age of nodes + ``epsilon_b (float)``: + ?? + ``epsilon_n (float)``: + ?? + ``alpha (float)``: + ?? + ``delta (float)``: + ?? + ``lambda_ (int)``: + ?? + ``nodes (List[Neuron])``: + List of nodes + ``edges (List[Edge])``: + List of edges + """ self.input_dim = input_dim self.max_nodes = max_nodes self.max_age = max_age @@ -21,93 +85,124 @@ class GrowingNeuralGas: self.alpha = alpha self.delta = delta self.lambda_ = lambda_ - + self.nodes = [] self.edges = [] - + # Initialize with two random neurons self.nodes.append(Neuron(np.random.rand(input_dim))) self.nodes.append(Neuron(np.random.rand(input_dim))) self.edges.append(Edge(self.nodes[0], self.nodes[1])) - def fit(self, X, num_iterations=1000): - bar = IncrementalBar("Training", max=num_iterations, suffix="%(percent)d%% - %(eta)ds/%(elapsed)ds") + async def fit(self, X, num_iterations=1000): + """ + Train the model on the given data + + Args: + ---------- + ``X (NDArray)``: + The data to train on + ``num_iterations (int, optional)``: + The number of iteration to train the model. Defaults to 1000. + + Returns: + ---------- + ``Nothing`` + """ + bar = ShadyBar( + "Training", + max=num_iterations, + suffix="%(percent).2f%% - %(index)d/%(max)d (%(avg).3fs per iteration)", + ) for iteration in range(num_iterations): # Step 1: Select a random input sample x = X[np.random.randint(len(X))] - + # Step 2: Find the nearest and second nearest neurons - dists = np.array([np.linalg.norm(node.position - x) for node in self.nodes]) - winner_idx = np.argmin(dists) - winner = self.nodes[winner_idx] - dists[winner_idx] = np.inf - second_winner_idx = np.argmin(dists) - second_winner = self.nodes[second_winner_idx] - + winner, second_winner = await self.find_nearest(x) + # Step 3: Increment age of edges connected to the winner for edge in self.edges: if winner in edge.nodes: edge.age += 1 - + # Step 4: Add error to the winner winner.error += np.linalg.norm(winner.position - x) ** 2 - + # Step 5: Move the winner and its topological neighbors winner.position += self.epsilon_b * (x - winner.position) for edge in self.edges: if winner in edge.nodes: other = edge.nodes[0] if edge.nodes[1] is winner else edge.nodes[1] other.position += self.epsilon_n * (x - other.position) - + # Step 6: Connect the winner with the second winner edge = self.find_edge(winner, second_winner) if edge: edge.age = 0 else: self.edges.append(Edge(winner, second_winner)) - + # Step 7: Remove old edges self.edges = [edge for edge in self.edges if edge.age <= self.max_age] - + # Step 8: Remove isolated nodes - self.nodes = [node for node in self.nodes if any(node in edge.nodes for edge in self.edges)] - + self.nodes = [ + node + for node in self.nodes + if any(node in edge.nodes for edge in self.edges) + ] + # Step 9: Insert new nodes if iteration % self.lambda_ == 0 and len(self.nodes) < self.max_nodes: self.insert_node() - + # Step 10: Decrease all errors for node in self.nodes: node.error *= self.delta - + bar.next() bar.finish() + async def find_nearest(self, x): + dists = np.array([np.linalg.norm(node.position - x) for node in self.nodes]) + winner_idx = np.argmin(dists) + winner = self.nodes[winner_idx] + dists[winner_idx] = np.inf + second_winner_idx = np.argmin(dists) + second_winner = self.nodes[second_winner_idx] + return winner, second_winner + def find_edge(self, node1, node2): for edge in self.edges: if node1 in edge.nodes and node2 in edge.nodes: return edge return None - + def insert_node(self): # Find the node with the largest error q = max(self.nodes, key=lambda node: node.error) - + # Find the neighbor with the largest error connected_edges = [edge for edge in self.edges if q in edge.nodes] - f = max((node for edge in connected_edges for node in edge.nodes if node is not q), key=lambda node: node.error) - + f = max( + (node for edge in connected_edges for node in edge.nodes if node is not q), + key=lambda node: node.error, + ) + # Insert a new neuron halfway between q and f r_position = (q.position + f.position) / 2 r = Neuron(r_position) self.nodes.append(r) - + # Remove the edge between q and f and add edges q-r and r-f - self.edges = [edge for edge in self.edges if not (q in edge.nodes and f in edge.nodes)] + self.edges = [ + edge for edge in self.edges if not (q in edge.nodes and f in edge.nodes) + ] self.edges.append(Edge(q, r)) self.edges.append(Edge(r, f)) - + # Decrease the error of q and f q.error *= self.alpha f.error *= self.alpha - r.error = q.error \ No newline at end of file + r.error = q.error diff --git a/3D_app/main.py b/3D_app/main.py index e744d74..928200d 100644 --- a/3D_app/main.py +++ b/3D_app/main.py @@ -2,9 +2,10 @@ import dash from dash import dcc, html, ALL, DiskcacheManager from dash.dependencies import Input, Output, State import dash_bootstrap_components as dbc -from os import listdir, mkdir +from os import listdir, mkdir, getenv from os.path import isfile, join import diskcache +from json import load # on crée l'application @@ -14,6 +15,10 @@ app = dash.Dash( use_pages=True, ) +config = load(open("config.json", "r")) + +app.title = config["app_title"] + cache = diskcache.Cache("./cache") background_callback_manager = DiskcacheManager(cache) @@ -154,7 +159,11 @@ modal_save = dbc.Modal( id="save-format", options=[ {"label": "Save raw dataset", "value": "raw"}, - {"label": "Save filtered dataset", "value": "filt"}, + { + "label": "Save filtered dataset", + "value": "filt", + "disabled": True, + }, ], value="raw", inline=True, @@ -220,7 +229,14 @@ navmenu = html.Div( dbc.Nav( [ dbc.NavLink( - html.Div(page["name"], className="ms-2"), + html.Div( + [ + html.B(page["name"]), + html.Br(), + html.Span(page["description"]), + ], + className="ms-2", + ), href=page["path"], active="exact", ) @@ -239,6 +255,17 @@ navmenu = html.Div( ), ) +save_toast = dbc.Toast( + [html.P("File saved successfully in the saves folder!", className="mb-0")], + id="save-toast", + header="Success", + icon="success", + duration=4000, + is_open=False, + dismissable=True, + style={"position": "fixed", "top": 10, "right": 10, "width": 350}, +) + # on défini la navbar nav_bar = dbc.Navbar( dbc.Container( @@ -275,7 +302,7 @@ nav_bar = dbc.Navbar( dbc.Col( html.Div( [ - html.H3("3D app"), + html.H3(f"{app.title}"), html.P( "IJL - Institut Jean Lamour / project stage M2 EEA 2023-2024", ), @@ -313,6 +340,7 @@ nav_bar = dbc.Navbar( modal_open, modal_save, navmenu, + save_toast ], align="center", style={"width": "100%"}, @@ -391,13 +419,18 @@ def toggle_open(n1, n2, is_open): return not is_open return is_open + @app.callback( Output("save-modal", "is_open"), - [Input("save-button", "n_clicks"), Input("save-close", "n_clicks")], + [ + Input("save-button", "n_clicks"), + Input("save-close", "n_clicks"), + Input("save-save", "n_clicks"), + ], [dash.dependencies.State("save-modal", "is_open")], ) -def toggle_save(n1, n2, is_open): - if n1 or n2: +def toggle_save(n1, n2, n3, is_open): + if n1 or n2 or n3: return not is_open return is_open @@ -419,8 +452,12 @@ def refresh_files(n): if isfile(join("Dataset/saves", file)) ] + @app.callback( - [Output("open-modal", "is_open", allow_duplicate=True), Output("store-files", "data")], + [ + Output("open-modal", "is_open", allow_duplicate=True), + Output("store-files", "data"), + ], Input({"type": "file-item", "index": ALL}, "n_clicks"), State({"type": "file-item", "index": ALL}, "children"), prevent_initial_call=True, @@ -431,15 +468,16 @@ def open_file(n, filenames): return [None, ""] file_index = ctx.triggered[0]["prop_id"].split(".")[0] file_index = eval(file_index) - filename = filenames[file_index['index']] + filename = filenames[file_index["index"]] return [False, filename] + @app.callback( Output("save-format", "options"), [Input("store-filters", "data")], ) def update_save_format(filters): - if filters: + if filters != {}: return [ {"label": "Save raw dataset", "value": "raw"}, {"label": "Save filtered dataset", "value": "filt"}, @@ -449,6 +487,7 @@ def update_save_format(filters): {"label": "Save filtered dataset", "value": "filt", "disabled": True}, ] + # on lance l'application if __name__ == "__main__": - app.run(debug=True, port="8051", threaded=True) + app.run(debug=config["debug"] or False, port=config["port"] or "8051", threaded=True) diff --git a/3D_app/pages/ascan.py b/3D_app/pages/ascan.py index 173a98b..0c8c5c9 100644 --- a/3D_app/pages/ascan.py +++ b/3D_app/pages/ascan.py @@ -12,7 +12,11 @@ from Bscan_Cscan_trait import * dash.register_page( - __name__, path="/ascan", title="A-Scan filters", name="A-Scan filters" + __name__, + path="/ascan", + title="A-Scan filters", + name="A-Scan filters", + description="Apply filters on the A-Scan", ) # on définit le dossier et les fichiers à lire @@ -50,7 +54,7 @@ layout = html.Div( dbc.Select( id="select-ascan-filter1", options=[ - {"label": "transformer du Hilbert", "value": "1"}, + {"label": "Transformer du Hilbert", "value": "1"}, ], value=1, style={"margin-bottom": "15px"}, @@ -63,7 +67,7 @@ layout = html.Div( dbc.Select( id="select-ascan-filter2", options=[ - {"label": "sans filtre ", "value": "2"}, + {"label": "No filter ", "value": "2"}, {"label": "filtre passe bas ", "value": "3"}, {"label": "filtre de moyenne mobile", "value": "4"}, {"label": "filtre adaptatif (wiener)", "value": "5"}, @@ -87,7 +91,7 @@ layout = html.Div( dbc.Select( id="select-ascan-filter3", options=[ - {"label": "sans filtre ", "value": "2"}, + {"label": "No filter ", "value": "2"}, {"label": "filtre passe bas ", "value": "3"}, {"label": "filtre de moyenne mobile", "value": "4"}, {"label": "filtre adaptatif (wiener)", "value": "5"}, @@ -108,17 +112,25 @@ layout = html.Div( ), dbc.Col( [ - dbc.Label( - "applique les filtres selections sur tous les data", - style={"marginRight": "5px"}, - ), - dbc.Button( - id="button-validate-filter", - children=dbc.Spinner( - html.Div("Valider", id="loading"), show_initially=False - ), - color="primary", - style={"marginBottom": "15px"}, + html.Div( + [ + dbc.Label( + "Apply selection on all data", + style={"margin": "auto 0"}, + ), + dbc.Button( + id="button-validate-filter", + children=dbc.Spinner( + html.Div("Apply", id="loading"), + show_initially=False, + ), + color="primary", + ), + ], + style={ + "justifyContent": "space-between", + "display": "flex", + }, ), ], width=3, @@ -128,9 +140,9 @@ layout = html.Div( dbc.Row( [ dbc.Col( - [html.Br(), html.B(" paramètre du 1er filtre ")], + html.B(" 1st filter settings "), width=2, - style={"textAlign": "center"}, + style={"textAlign": "center", "padding": "3.5vh"}, ), dbc.Col( [ @@ -147,11 +159,11 @@ layout = html.Div( ), dbc.Col( [ - dbc.Label("cut off ", html_for="cut off"), + dbc.Label("Cut Off ", html_for="cut off"), dbc.Input( id="input-ascan-solo-cutoff", type="number", - placeholder="cut_off", + placeholder="Cut Off", value=1, step=0.1, ), @@ -160,11 +172,11 @@ layout = html.Div( ), dbc.Col( [ - dbc.Label("order ", html_for="order"), + dbc.Label("Order ", html_for="order"), dbc.Input( id="input-ascan-solo-order", type="number", - placeholder="order", + placeholder="Order", value=1, step=1, ), @@ -173,11 +185,11 @@ layout = html.Div( ), dbc.Col( [ - dbc.Label("window size ", html_for="window size"), + dbc.Label("Window size ", html_for="window size"), dbc.Input( id="input-ascan-solo-windowsize", type="number", - placeholder="window_size", + placeholder="Window_size", value=1, step=1, ), @@ -185,9 +197,9 @@ layout = html.Div( width=1, ), dbc.Col( - [html.Br(), html.B(" paramètre du 2e filtre ")], + html.B(" 2nd filter settings "), width=2, - style={"textAlign": "center"}, + style={"textAlign": "center", "padding": "3.5vh"}, ), dbc.Col( [ @@ -204,11 +216,11 @@ layout = html.Div( ), dbc.Col( [ - dbc.Label("cut off ", html_for="cut off"), + dbc.Label("Cut Off ", html_for="cut off"), dbc.Input( id="input-ascan-solo-cutoff-2", type="number", - placeholder="cut_off", + placeholder="Cut Off", value=1, step=0.1, ), @@ -217,11 +229,11 @@ layout = html.Div( ), dbc.Col( [ - dbc.Label("order ", html_for="order"), + dbc.Label("Order ", html_for="order"), dbc.Input( id="input-ascan-solo-order-2", type="number", - placeholder="order", + placeholder="Order", value=1, step=1, ), @@ -230,11 +242,11 @@ layout = html.Div( ), dbc.Col( [ - dbc.Label("window size ", html_for="window size"), + dbc.Label("Window size ", html_for="window size"), dbc.Input( id="input-ascan-solo-windowsize-2", type="number", - placeholder="window_size", + placeholder="Window_size", value=1, step=1, ), @@ -271,7 +283,7 @@ layout = html.Div( ), ] ), - dbc.Label("x"), + dbc.Label("X"), dcc.Slider( id="layer-slider-ascan-solo-x", min=1, @@ -282,7 +294,7 @@ layout = html.Div( str(i): str(i) for i in range(1, dim_z + 1, max(1, int(dim_z / 20))) }, ), - dbc.Label("y"), + dbc.Label("Y"), dcc.Slider( id="layer-slider-ascan-solo-y", min=1, @@ -293,7 +305,7 @@ layout = html.Div( str(i): str(i) for i in range(1, dim_x + 1, max(1, int(dim_x / 20))) }, ), - dbc.Label("z"), + dbc.Label("Z"), dcc.RangeSlider( id="layer-slider-ascan-solo-z", min=1, @@ -304,10 +316,12 @@ layout = html.Div( str(i): str(i) for i in range(1, dim_y + 1, max(1, int(dim_y / 20))) }, ), + html.Div(id="loading-fullscreen"), ], style={"padding": "20px"}, ) + @callback( Output("store-filters", "data"), [ @@ -428,7 +442,7 @@ def update_filter_values(select_filtre_1, select_filtre_2): State("input-ascan-solo-cutoff-2", "value"), State("input-ascan-solo-order-2", "value"), State("input-ascan-solo-windowsize-2", "value"), - ] + ], ) def update_heatmap_ascan( select_ascan_x, @@ -484,11 +498,23 @@ def update_heatmap_ascan( int(windowsize_filtre_2), ) print("fin du traitement") - bouton = "Valider" + bouton = "Apply" if n_clicks != None: - data_traits= Cscant(volume,int(selec_transforme_hilbert),int(select_filtre_1),int(select_filtre_2), - float(fs_filtre_1),float(cutoff_filtre_1),int(order_filtre_1),int(windowsize_filtre_1),float(fs_filtre_2),float(cutoff_filtre_2),int(order_filtre_2),int(windowsize_filtre_2)) - bouton = "Valider" + data_traits = Cscant( + volume, + int(selec_transforme_hilbert), + int(select_filtre_1), + int(select_filtre_2), + float(fs_filtre_1), + float(cutoff_filtre_1), + int(order_filtre_1), + int(windowsize_filtre_1), + float(fs_filtre_2), + float(cutoff_filtre_2), + int(order_filtre_2), + int(windowsize_filtre_2), + ) + bouton = "Apply" fig = px.line(title="A-scan") new_trace = go.Scatter(y=data_avec_traitement, mode="lines", name=" Ascan trait ") fig.add_trace(new_trace) @@ -498,10 +524,21 @@ def update_heatmap_ascan( fig.add_trace(new_trace) fig.update_layout(xaxis_title="indix", yaxis_title="amplitude") + data_bscan = Bscant( + volume[select_ascan_y - 1, select_ascan_z[0] : select_ascan_z[1], :], + int(selec_transforme_hilbert), + int(select_filtre_1), + int(select_filtre_2), + float(fs_filtre_1), + float(cutoff_filtre_1), + int(order_filtre_1), + int(windowsize_filtre_1), + float(fs_filtre_2), + float(cutoff_filtre_2), + int(order_filtre_2), + int(windowsize_filtre_2), + ) - data_bscan=Bscant(volume[select_ascan_y - 1, select_ascan_z[0] : select_ascan_z[1], :],int(selec_transforme_hilbert),int(select_filtre_1),int(select_filtre_2),float(fs_filtre_1), - float(cutoff_filtre_1),int(order_filtre_1),int(windowsize_filtre_1),float(fs_filtre_2),float(cutoff_filtre_2),int(order_filtre_2),int(windowsize_filtre_2),) - fig2 = px.imshow( data_bscan, color_continuous_scale="Jet", diff --git a/3D_app/pages/gng.py b/3D_app/pages/gng.py index 4563179..4150045 100644 --- a/3D_app/pages/gng.py +++ b/3D_app/pages/gng.py @@ -1,11 +1,14 @@ +import asyncio import dash import plotly.graph_objects as go -from dash import html, dcc, callback, Input, Output +from dash import html, dcc, callback, Input, Output, State import dash_bootstrap_components as dbc from gng2 import GrowingNeuralGas from sklearn import datasets as sk -dash.register_page(__name__, path="/gng", title="GNG", name="GNG") +dash.register_page( + __name__, path="/gng", title="GNG", name="GNG", description="Growing Neural Gas" +) # Generate synthetic data X, Y = sk.make_moons(n_samples=200, noise=0.1, random_state=0) @@ -63,22 +66,47 @@ layout = html.Div( ] ) +@callback( + Output("base-graph", "figure", allow_duplicate=True), + [Input("noise", "value")], + prevent_initial_call=True, +) +def update_base_graph(noise): + X, Y = sk.make_moons(n_samples=200, noise=noise, random_state=0) + + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=X[Y == 0, 0], y=X[Y == 0, 1], mode="markers", marker=dict(color="blue") + ) + ) + fig.add_trace( + go.Scatter( + x=X[Y == 1, 0], y=X[Y == 1, 1], mode="markers", marker=dict(color="red") + ) + ) + fig.update_layout( + showlegend=False, + margin=dict(l=0, r=0, t=0, b=0), + xaxis=dict(visible=False), + yaxis=dict(visible=False), + title="Base Data", + ) + + return fig + @callback( [Output("gng-graph", "figure"), Output("base-graph", "figure")], [ Input("generate-gng", "n_clicks"), - Input("noise", "value"), - Input("iterations", "value"), - Input("nodes", "value"), ], + [State("noise", "value"), State("iterations", "value"), State("nodes", "value")], ) def generate_gng(n_clicks, noise, iterations, nodes): global clics X, Y = sk.make_moons(n_samples=200, noise=noise, random_state=0) - print(X, Y) - fig2 = go.Figure() fig2.add_trace( go.Scatter( @@ -100,7 +128,7 @@ def generate_gng(n_clicks, noise, iterations, nodes): if n_clicks != clics: gng = GrowingNeuralGas(input_dim=2, max_nodes=nodes) - gng.fit(X, num_iterations=iterations) + asyncio.run(gng.fit(X, num_iterations=iterations)) fig = go.Figure() for edge in gng.edges: diff --git a/3D_app/pages/home.py b/3D_app/pages/home.py index 80b0b2d..786cd89 100644 --- a/3D_app/pages/home.py +++ b/3D_app/pages/home.py @@ -9,7 +9,7 @@ from util import * from os.path import join import diskcache -dash.register_page(__name__, path="/") +dash.register_page(__name__, path="/", description="The home page of the web app") # on définit le dossier et les fichiers à lire dossier = "Dataset/Shear_transform" @@ -263,7 +263,7 @@ Ascan_card = dbc.Fade( dcc.Slider( id="layer-slider-ascan-fullscreen", min=0, - max=dim_x, + max=dim_x - 1, value=0, step=1, marks={ @@ -345,7 +345,7 @@ Bscan_card_xy = dbc.Fade( dcc.Slider( id="layer-slider-bscan-zx-fullscreen", min=0, - max=dim_x, + max=dim_x - 1, value=0, step=1, marks={ @@ -470,16 +470,13 @@ layout = html.Div( # on défini les callbacks # callback pour le plot 3D @callback( - [Output("3dplot", "figure")], + [Output("3dplot", "figure"), Output("fade-3dplot", "is_in")], [ Input("iso-slider", "value"), Input("y-slider", "value"), Input("store-settings", "data"), ], [dash.dependencies.State("fade-3dplot", "is_in")], - running=[ - (Output("fade-3dplot", "is_in"), False, True), - ], ) def update_3dplot(iso_value, y_values, settings, is_in): if settings["use_real_values"]: @@ -511,7 +508,7 @@ def update_3dplot(iso_value, y_values, settings, is_in): ) ) - return [fig] + return [fig, True] # callback pour le plot 3D en plein écran @@ -547,18 +544,15 @@ def update_3dplot_fullscreen(iso_value, y_values): # callback pour le A-scan @callback( - [Output("heatmap-ascan", "figure")], + [Output("heatmap-ascan", "figure"), Output("fade-ascan", "is_in")], [Input("layer-slider-bscan-zx", "value"), Input("layer-slider-bscan-xy", "value")], [dash.dependencies.State("fade-ascan", "is_in")], - running=[ - (Output("fade-ascan", "is_in"), False, True), - ], prevent_initial_call=True, ) def update_heatmap_ascan(layer, layer1, is_in): fig = px.line(y=volume[layer - 1, :, layer1], title="A-scan") - return [fig] + return [fig, True] # callback pour le A-scan en plein écran @@ -576,12 +570,10 @@ def update_heatmap_ascan_fullscreen(layer): [ Output("heatmap-bscan-zx", "figure"), Output("store-bscan-zx-layer", "data"), + Output("fade-bscan-xy", "is_in"), ], [Input("layer-slider-bscan-zx", "value")], [dash.dependencies.State("fade-bscan-zx", "is_in")], - running=[ - (Output("fade-bscan-xy", "is_in"), False, True), - ], prevent_initial_call=True, ) def update_heatmap_bscan_zx(layer, is_in): @@ -592,7 +584,7 @@ def update_heatmap_bscan_zx(layer, is_in): title="B-scan ZX", ) - return [fig, layer] + return [fig, layer, True] # callback pour les B-scan ZX en plein écran @@ -613,19 +605,20 @@ def update_heatmap_bscan_zx_fullscreen(layer): # callback pour les B-scan ZX @callback( - [Output("heatmap-bscan-xy", "figure"), Output("store-bscan-xy-layer", "data")], + [ + Output("heatmap-bscan-xy", "figure"), + Output("store-bscan-xy-layer", "data"), + Output("fade-bscan-zx", "is_in"), + ], [Input("layer-slider-bscan-xy", "value")], [dash.dependencies.State("fade-bscan-xy", "is_in")], - running=[ - (Output("fade-bscan-zx", "is_in"), False, True), - ], prevent_initial_call=True, ) def update_heatmap_bscan_xy(layer, is_in): fig = go.Figure(data=go.Heatmap(z=volume[:, :, layer], colorscale="Jet")) fig.update_layout(title="B-scan XY") - return [fig, layer] + return [fig, layer, True] # callback pour les B-scan ZX en plein écran @@ -830,7 +823,7 @@ def update_settings( prevent_initial_call=True, ) def redef_data(data): - global volume, dim_x, dim_y, dim_z, X, Y, Z + global volume, pre_volume, dim_x, dim_y, dim_z, X, Y, Z volume = pre_volume[ :: data["echantillonage_x"], :: data["echantillonage_y"], @@ -914,30 +907,41 @@ def apply_filters(data, settings): @callback( Input("store-files", "data"), - State("store-settings", "data"), + [ + State("store-settings", "data"), + State("layer-slider-bscan-zx", "value"), + State("layer-slider-bscan-xy", "value"), + State("iso-slider", "value"), + State("y-slider", "value"), + ], prevent_initial_call=True, ) -def update_files(data, settings): +def update_files(data, settings, layer_zx, layer_xy, iso, y): global pre_volume, dim_y if data is None or data == "": return None pre_volume = np.load(join("Dataset/saves", data)) redef_data(settings) - update_3dplot(0, [0, dim_y / 2], settings, False) - update_heatmap_ascan(0, 0, False) - update_heatmap_bscan_zx(0, False) - update_heatmap_bscan_xy(0, False) + update_3dplot(iso, y, settings, False) + update_heatmap_ascan(layer_zx, layer_xy, False) + update_heatmap_bscan_zx(layer_zx, False) + update_heatmap_bscan_xy(layer_xy, False) return None + @callback( + Output("save-toast", "is_open"), Input("save-save", "n_clicks"), - [State("save-input", "value"), State("save-format", "value")], + [ + State("save-input", "value"), + State("save-format", "value"), + ], ) def save_data(n_clicks, filename, format): if n_clicks is None: - return None + return False if format == "raw": np.save(join("Dataset/saves", filename), pre_volume) else: np.save(join("Dataset/saves", filename), volume) - return None \ No newline at end of file + return True