From 8680a370a52e7f20ca39ada0f4343f9dcec81544 Mon Sep 17 00:00:00 2001 From: Andrew Golovashevich Date: Mon, 9 Feb 2026 11:41:31 +0300 Subject: [PATCH] [lab4] GUI --- lab4/src/algo/comparation_operators_model.rs | 6 +- lab4/src/algo/images.rs | 209 ++++++++++--------- lab4/src/algo/mod.rs | 8 +- lab4/src/main.rs | 183 +++++++++++++++- 4 files changed, 299 insertions(+), 107 deletions(-) diff --git a/lab4/src/algo/comparation_operators_model.rs b/lab4/src/algo/comparation_operators_model.rs index 3c60de1..2d58e42 100644 --- a/lab4/src/algo/comparation_operators_model.rs +++ b/lab4/src/algo/comparation_operators_model.rs @@ -1,7 +1,7 @@ use crate::algo::net::Net; use rand::Rng; -struct ComparationOperatorsModel +pub struct ComparationOperatorsModel where [(); { W * H }]:, { @@ -75,6 +75,10 @@ where self.net.compute(); return self.net.output_data; } + pub fn set_input_from_etalon(&mut self, etalon: usize) { + assert!(etalon < self.etalons.len()); + self.input = self.etalons[etalon]; + } } struct TrainIterator<'a, const W: usize, const H: usize, const ClassCount: usize> diff --git a/lab4/src/algo/images.rs b/lab4/src/algo/images.rs index e42ad96..d484ca3 100644 --- a/lab4/src/algo/images.rs +++ b/lab4/src/algo/images.rs @@ -3,112 +3,115 @@ use std::array; fn image(raw: [&str; H]) -> [[bool; W]; H] { return raw.map(|r| { let mut it = r.chars().map(|c| match c { - '+' => true, - '0' => false, - _ => panic!(), - }); + '*' => true, + ' ' => false, + _ => panic!(), + }); return array::from_fn(|_| it.next().unwrap()) }); } -pub fn gen_images() -> [[[bool; 7]; 7]; 8] { - return [ - // < - image( - [ - " ", - " * ", - " * ", - " * ", - " * ", - " * ", - " ", - ] - ), - // <= - image( - [ - " * ", - " * ", - " * ", - " * ", - " * * ", - " * ", - " * ", - ] - ), - // > - image( - [ - " ", - " * ", - " * ", - " * ", - " * ", - " * ", - " ", - ] - ), - // >= - image( - [ - " * ", - " * ", - " * ", - " * ", - " * * ", - " * ", - " * ", - ] - ), - // = - image( - [ - " ", - " ", - " ***** ", - " ", - " ***** ", - " ", - " ", - ] - ), - // != - image( - [ - " ", - " * ", - " ***** ", - " * ", - " ***** ", - " * ", - " ", - ] - ), - // ≡ - image( - [ - " ", - " ***** ", - " ", - " ***** ", - " ", - " ***** ", - " ", - ] - ), - // ≈ - image( - [ - " * ", - " * * * ", - " * ", - " ", - " * ", - " * * * ", - " * ", - ] - ), - ] +pub fn gen_images() -> ([&'static str; 8], [[[bool; 7]; 7]; 8]) { + return ( + ["<", "≤", ">", "≥", "=", "≠", "≡", "≈"], + [ + // < + image( + [ + " ", + " * ", + " * ", + " * ", + " * ", + " * ", + " ", + ] + ), + // <= + image( + [ + " * ", + " * ", + " * ", + " * ", + " * * ", + " * ", + " * ", + ] + ), + // > + image( + [ + " ", + " * ", + " * ", + " * ", + " * ", + " * ", + " ", + ] + ), + // >= + image( + [ + " * ", + " * ", + " * ", + " * ", + " * * ", + " * ", + " * ", + ] + ), + // = + image( + [ + " ", + " ", + " ***** ", + " ", + " ***** ", + " ", + " ", + ] + ), + // != + image( + [ + " ", + " * ", + " ***** ", + " * ", + " ***** ", + " * ", + " ", + ] + ), + // ≡ + image( + [ + " ", + " ***** ", + " ", + " ***** ", + " ", + " ***** ", + " ", + ] + ), + // ≈ + image( + [ + " * ", + " * * * ", + " * ", + " ", + " * ", + " * * * ", + " * ", + ] + ), + ] + ) } \ No newline at end of file diff --git a/lab4/src/algo/mod.rs b/lab4/src/algo/mod.rs index ed722e4..a48ac1b 100644 --- a/lab4/src/algo/mod.rs +++ b/lab4/src/algo/mod.rs @@ -1,5 +1,9 @@ +mod comparation_operators_model; mod compute; mod fix; -mod net; -mod comparation_operators_model; mod images; +mod net; + +pub use comparation_operators_model::ComparationOperatorsModel; + +pub use images::gen_images; diff --git a/lab4/src/main.rs b/lab4/src/main.rs index 552b8c2..4ece3c2 100644 --- a/lab4/src/main.rs +++ b/lab4/src/main.rs @@ -1,3 +1,184 @@ +#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release +#![allow(incomplete_features)] +#![feature(generic_const_exprs)] + mod algo; -fn main() {} \ No newline at end of file +use crate::algo::{ComparationOperatorsModel, gen_images}; +use eframe::egui; +use eframe::egui::{Ui, Widget}; +use eframe::emath::Numeric; +use egui_extras::{Column, TableBuilder}; +use std::cmp::min; +use std::ops::RangeInclusive; + +fn main() -> eframe::Result { + let options = eframe::NativeOptions { + viewport: egui::ViewportBuilder::default().with_inner_size([640.0, 400.0]), + ..Default::default() + }; + eframe::run_native( + "Neural net", + options, + Box::new(|_cc| Ok(Box::::default())), + ) +} + +enum TrainingState { + NoTrain, + Training { + epochs_left: usize, + epochs_per_update: usize, + n: f64, + }, +} + +struct MyApp { + model: ComparationOperatorsModel<7, 7, 8>, + training: TrainingState, + n: f64, + epochs_count: usize, + symbols: [&'static str; 8], + compute_result: [f64; 8], +} + +impl Default for MyApp { + fn default() -> Self { + let imgs = gen_images(); + return Self { + model: ComparationOperatorsModel::new(imgs.1), + training: TrainingState::NoTrain, + n: 0.1, + epochs_count: 1, + symbols: imgs.0, + compute_result: [0.0; 8], + }; + } +} + +fn _slider( + ui: &mut Ui, + name: &str, + storage: &mut T, + range: RangeInclusive, + step: f64, +) { + let label = ui.label(name); + + ui.scope(|ui| { + ui.spacing_mut().slider_width = ui.available_width() + - ui.spacing().interact_size.x + - ui.spacing().button_padding.x * 2.0; + ui.add(egui::Slider::new(storage, range).step_by(step)) + .labelled_by(label.id); + }); +} + +impl eframe::App for MyApp { + fn update(&mut self, ui: &eframe::egui::Context, _frame: &mut eframe::Frame) { + egui::CentralPanel::default().show(ui, |ui| { + ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| { + _slider(ui, "η", &mut self.n, 0.0..=1.0, 0.001); + ui.label(""); + _slider(ui, "Epochs count", &mut self.epochs_count, 1..=500, 1f64); + ui.label(""); + + ui.horizontal(|ui| { + if ui.button("Train").clicked() { + self.training = TrainingState::Training { + epochs_left: self.epochs_count, + epochs_per_update: min(10, self.epochs_count / 60), + n: 0.0, + } + } + + match &mut self.training { + TrainingState::NoTrain => { + egui::ProgressBar::new(1.0).ui(ui); + } + TrainingState::Training { + epochs_left, + epochs_per_update, + n, + } => { + if *epochs_left >= *epochs_per_update { + for _ in 0..*epochs_per_update { + self.model.train_epoch(*n); + } + *epochs_left -= *epochs_per_update; + } else { + for _ in 0..*epochs_left { + self.model.train_epoch(*n); + } + *epochs_left = 0; + } + + ui.add_enabled_ui(true, |ui| { + egui::ProgressBar::new( + 1.0 - (*epochs_left as f32) / (self.epochs_count as f32), + ) + .ui(ui) + }); + + if 0 == *epochs_left { + self.training = TrainingState::NoTrain; + } + } + } + }) + }); + + ui.label(""); + + ui.horizontal(|ui| { + for (i, s) in self.symbols.iter().enumerate() { + if ui.button(s.to_owned()).clicked() { + self.model.set_input_from_etalon(i); + } + } + }); + + ui.label(""); + + for y in 0..7 { + ui.horizontal(|ui| { + for x in 0..7 { + ui.checkbox(&mut self.model.input[y][x], ""); + } + }); + } + + ui.label(""); + ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| { + if ui.button("Check").clicked() { + self.compute_result = self.model.calculate_predictions(); + } + }); + + TableBuilder::new(ui) + .striped(true) // Alternating row colors + .resizable(true) + .column(Column::remainder()) + .column(Column::remainder()) + .header(20.0, |mut header| { + header.col(|ui| { + ui.label("Class"); + }); + header.col(|ui| { + ui.label("Probability"); + }); + }) + .body(|body| { + body.rows(20.0, self.symbols.len(), |mut row| { + let i = row.index(); + row.col(|ui| { + ui.label(self.symbols[i]); + }); + row.col(|ui| { + ui.label(self.compute_result[i].to_string()); + }); + }); + }); + }); + } +}