[lab4] GUI

This commit is contained in:
Andrew Golovashevich 2026-02-09 11:41:31 +03:00
parent 7071338f96
commit 8680a370a5
4 changed files with 299 additions and 107 deletions

View File

@ -1,7 +1,7 @@
use crate::algo::net::Net; use crate::algo::net::Net;
use rand::Rng; use rand::Rng;
struct ComparationOperatorsModel<const W: usize, const H: usize, const ClassCount: usize> pub struct ComparationOperatorsModel<const W: usize, const H: usize, const ClassCount: usize>
where where
[(); { W * H }]:, [(); { W * H }]:,
{ {
@ -75,6 +75,10 @@ where
self.net.compute(); self.net.compute();
return self.net.output_data; return self.net.output_data;
} }
pub fn set_input_from_etalon(&mut self, etalon: usize) {
assert!(etalon < self.etalons.len());
self.input = self.etalons[etalon];
}
} }
struct TrainIterator<'a, const W: usize, const H: usize, const ClassCount: usize> struct TrainIterator<'a, const W: usize, const H: usize, const ClassCount: usize>

View File

@ -3,8 +3,8 @@ use std::array;
fn image<const H: usize, const W: usize>(raw: [&str; H]) -> [[bool; W]; H] { fn image<const H: usize, const W: usize>(raw: [&str; H]) -> [[bool; W]; H] {
return raw.map(|r| { return raw.map(|r| {
let mut it = r.chars().map(|c| match c { let mut it = r.chars().map(|c| match c {
'+' => true, '*' => true,
'0' => false, ' ' => false,
_ => panic!(), _ => panic!(),
}); });
@ -12,8 +12,10 @@ fn image<const H: usize, const W: usize>(raw: [&str; H]) -> [[bool; W]; H] {
}); });
} }
pub fn gen_images() -> [[[bool; 7]; 7]; 8] { pub fn gen_images() -> ([&'static str; 8], [[[bool; 7]; 7]; 8]) {
return [ return (
["<", "", ">", "", "=", "", "", ""],
[
// < // <
image( image(
[ [
@ -111,4 +113,5 @@ pub fn gen_images() -> [[[bool; 7]; 7]; 8] {
] ]
), ),
] ]
)
} }

View File

@ -1,5 +1,9 @@
mod comparation_operators_model;
mod compute; mod compute;
mod fix; mod fix;
mod net;
mod comparation_operators_model;
mod images; mod images;
mod net;
pub use comparation_operators_model::ComparationOperatorsModel;
pub use images::gen_images;

View File

@ -1,3 +1,184 @@
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
mod algo; mod algo;
fn main() {} use crate::algo::{ComparationOperatorsModel, gen_images};
use eframe::egui;
use eframe::egui::{Ui, Widget};
use eframe::emath::Numeric;
use egui_extras::{Column, TableBuilder};
use std::cmp::min;
use std::ops::RangeInclusive;
fn main() -> eframe::Result {
let options = eframe::NativeOptions {
viewport: egui::ViewportBuilder::default().with_inner_size([640.0, 400.0]),
..Default::default()
};
eframe::run_native(
"Neural net",
options,
Box::new(|_cc| Ok(Box::<MyApp>::default())),
)
}
enum TrainingState {
NoTrain,
Training {
epochs_left: usize,
epochs_per_update: usize,
n: f64,
},
}
struct MyApp {
model: ComparationOperatorsModel<7, 7, 8>,
training: TrainingState,
n: f64,
epochs_count: usize,
symbols: [&'static str; 8],
compute_result: [f64; 8],
}
impl Default for MyApp {
fn default() -> Self {
let imgs = gen_images();
return Self {
model: ComparationOperatorsModel::new(imgs.1),
training: TrainingState::NoTrain,
n: 0.1,
epochs_count: 1,
symbols: imgs.0,
compute_result: [0.0; 8],
};
}
}
fn _slider<T: Numeric>(
ui: &mut Ui,
name: &str,
storage: &mut T,
range: RangeInclusive<T>,
step: f64,
) {
let label = ui.label(name);
ui.scope(|ui| {
ui.spacing_mut().slider_width = ui.available_width()
- ui.spacing().interact_size.x
- ui.spacing().button_padding.x * 2.0;
ui.add(egui::Slider::new(storage, range).step_by(step))
.labelled_by(label.id);
});
}
impl eframe::App for MyApp {
fn update(&mut self, ui: &eframe::egui::Context, _frame: &mut eframe::Frame) {
egui::CentralPanel::default().show(ui, |ui| {
ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| {
_slider(ui, "η", &mut self.n, 0.0..=1.0, 0.001);
ui.label("");
_slider(ui, "Epochs count", &mut self.epochs_count, 1..=500, 1f64);
ui.label("");
ui.horizontal(|ui| {
if ui.button("Train").clicked() {
self.training = TrainingState::Training {
epochs_left: self.epochs_count,
epochs_per_update: min(10, self.epochs_count / 60),
n: 0.0,
}
}
match &mut self.training {
TrainingState::NoTrain => {
egui::ProgressBar::new(1.0).ui(ui);
}
TrainingState::Training {
epochs_left,
epochs_per_update,
n,
} => {
if *epochs_left >= *epochs_per_update {
for _ in 0..*epochs_per_update {
self.model.train_epoch(*n);
}
*epochs_left -= *epochs_per_update;
} else {
for _ in 0..*epochs_left {
self.model.train_epoch(*n);
}
*epochs_left = 0;
}
ui.add_enabled_ui(true, |ui| {
egui::ProgressBar::new(
1.0 - (*epochs_left as f32) / (self.epochs_count as f32),
)
.ui(ui)
});
if 0 == *epochs_left {
self.training = TrainingState::NoTrain;
}
}
}
})
});
ui.label("");
ui.horizontal(|ui| {
for (i, s) in self.symbols.iter().enumerate() {
if ui.button(s.to_owned()).clicked() {
self.model.set_input_from_etalon(i);
}
}
});
ui.label("");
for y in 0..7 {
ui.horizontal(|ui| {
for x in 0..7 {
ui.checkbox(&mut self.model.input[y][x], "");
}
});
}
ui.label("");
ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| {
if ui.button("Check").clicked() {
self.compute_result = self.model.calculate_predictions();
}
});
TableBuilder::new(ui)
.striped(true) // Alternating row colors
.resizable(true)
.column(Column::remainder())
.column(Column::remainder())
.header(20.0, |mut header| {
header.col(|ui| {
ui.label("Class");
});
header.col(|ui| {
ui.label("Probability");
});
})
.body(|body| {
body.rows(20.0, self.symbols.len(), |mut row| {
let i = row.index();
row.col(|ui| {
ui.label(self.symbols[i]);
});
row.col(|ui| {
ui.label(self.compute_result[i].to_string());
});
});
});
});
}
}