[lab4] GUI

This commit is contained in:
Andrew Golovashevich 2026-02-09 11:41:31 +03:00
parent 7071338f96
commit 8680a370a5
4 changed files with 299 additions and 107 deletions

View File

@ -1,7 +1,7 @@
use crate::algo::net::Net;
use rand::Rng;
struct ComparationOperatorsModel<const W: usize, const H: usize, const ClassCount: usize>
pub struct ComparationOperatorsModel<const W: usize, const H: usize, const ClassCount: usize>
where
[(); { W * H }]:,
{
@ -75,6 +75,10 @@ where
self.net.compute();
return self.net.output_data;
}
pub fn set_input_from_etalon(&mut self, etalon: usize) {
assert!(etalon < self.etalons.len());
self.input = self.etalons[etalon];
}
}
struct TrainIterator<'a, const W: usize, const H: usize, const ClassCount: usize>

View File

@ -3,112 +3,115 @@ use std::array;
fn image<const H: usize, const W: usize>(raw: [&str; H]) -> [[bool; W]; H] {
return raw.map(|r| {
let mut it = r.chars().map(|c| match c {
'+' => true,
'0' => false,
_ => panic!(),
});
'*' => true,
' ' => false,
_ => panic!(),
});
return array::from_fn(|_| it.next().unwrap())
});
}
pub fn gen_images() -> [[[bool; 7]; 7]; 8] {
return [
// <
image(
[
" ",
" * ",
" * ",
" * ",
" * ",
" * ",
" ",
]
),
// <=
image(
[
" * ",
" * ",
" * ",
" * ",
" * * ",
" * ",
" * ",
]
),
// >
image(
[
" ",
" * ",
" * ",
" * ",
" * ",
" * ",
" ",
]
),
// >=
image(
[
" * ",
" * ",
" * ",
" * ",
" * * ",
" * ",
" * ",
]
),
// =
image(
[
" ",
" ",
" ***** ",
" ",
" ***** ",
" ",
" ",
]
),
// !=
image(
[
" ",
" * ",
" ***** ",
" * ",
" ***** ",
" * ",
" ",
]
),
// ≡
image(
[
" ",
" ***** ",
" ",
" ***** ",
" ",
" ***** ",
" ",
]
),
// ≈
image(
[
" * ",
" * * * ",
" * ",
" ",
" * ",
" * * * ",
" * ",
]
),
]
pub fn gen_images() -> ([&'static str; 8], [[[bool; 7]; 7]; 8]) {
return (
["<", "", ">", "", "=", "", "", ""],
[
// <
image(
[
" ",
" * ",
" * ",
" * ",
" * ",
" * ",
" ",
]
),
// <=
image(
[
" * ",
" * ",
" * ",
" * ",
" * * ",
" * ",
" * ",
]
),
// >
image(
[
" ",
" * ",
" * ",
" * ",
" * ",
" * ",
" ",
]
),
// >=
image(
[
" * ",
" * ",
" * ",
" * ",
" * * ",
" * ",
" * ",
]
),
// =
image(
[
" ",
" ",
" ***** ",
" ",
" ***** ",
" ",
" ",
]
),
// !=
image(
[
" ",
" * ",
" ***** ",
" * ",
" ***** ",
" * ",
" ",
]
),
// ≡
image(
[
" ",
" ***** ",
" ",
" ***** ",
" ",
" ***** ",
" ",
]
),
// ≈
image(
[
" * ",
" * * * ",
" * ",
" ",
" * ",
" * * * ",
" * ",
]
),
]
)
}

View File

@ -1,5 +1,9 @@
mod comparation_operators_model;
mod compute;
mod fix;
mod net;
mod comparation_operators_model;
mod images;
mod net;
pub use comparation_operators_model::ComparationOperatorsModel;
pub use images::gen_images;

View File

@ -1,3 +1,184 @@
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
mod algo;
fn main() {}
use crate::algo::{ComparationOperatorsModel, gen_images};
use eframe::egui;
use eframe::egui::{Ui, Widget};
use eframe::emath::Numeric;
use egui_extras::{Column, TableBuilder};
use std::cmp::min;
use std::ops::RangeInclusive;
fn main() -> eframe::Result {
let options = eframe::NativeOptions {
viewport: egui::ViewportBuilder::default().with_inner_size([640.0, 400.0]),
..Default::default()
};
eframe::run_native(
"Neural net",
options,
Box::new(|_cc| Ok(Box::<MyApp>::default())),
)
}
enum TrainingState {
NoTrain,
Training {
epochs_left: usize,
epochs_per_update: usize,
n: f64,
},
}
struct MyApp {
model: ComparationOperatorsModel<7, 7, 8>,
training: TrainingState,
n: f64,
epochs_count: usize,
symbols: [&'static str; 8],
compute_result: [f64; 8],
}
impl Default for MyApp {
fn default() -> Self {
let imgs = gen_images();
return Self {
model: ComparationOperatorsModel::new(imgs.1),
training: TrainingState::NoTrain,
n: 0.1,
epochs_count: 1,
symbols: imgs.0,
compute_result: [0.0; 8],
};
}
}
fn _slider<T: Numeric>(
ui: &mut Ui,
name: &str,
storage: &mut T,
range: RangeInclusive<T>,
step: f64,
) {
let label = ui.label(name);
ui.scope(|ui| {
ui.spacing_mut().slider_width = ui.available_width()
- ui.spacing().interact_size.x
- ui.spacing().button_padding.x * 2.0;
ui.add(egui::Slider::new(storage, range).step_by(step))
.labelled_by(label.id);
});
}
impl eframe::App for MyApp {
fn update(&mut self, ui: &eframe::egui::Context, _frame: &mut eframe::Frame) {
egui::CentralPanel::default().show(ui, |ui| {
ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| {
_slider(ui, "η", &mut self.n, 0.0..=1.0, 0.001);
ui.label("");
_slider(ui, "Epochs count", &mut self.epochs_count, 1..=500, 1f64);
ui.label("");
ui.horizontal(|ui| {
if ui.button("Train").clicked() {
self.training = TrainingState::Training {
epochs_left: self.epochs_count,
epochs_per_update: min(10, self.epochs_count / 60),
n: 0.0,
}
}
match &mut self.training {
TrainingState::NoTrain => {
egui::ProgressBar::new(1.0).ui(ui);
}
TrainingState::Training {
epochs_left,
epochs_per_update,
n,
} => {
if *epochs_left >= *epochs_per_update {
for _ in 0..*epochs_per_update {
self.model.train_epoch(*n);
}
*epochs_left -= *epochs_per_update;
} else {
for _ in 0..*epochs_left {
self.model.train_epoch(*n);
}
*epochs_left = 0;
}
ui.add_enabled_ui(true, |ui| {
egui::ProgressBar::new(
1.0 - (*epochs_left as f32) / (self.epochs_count as f32),
)
.ui(ui)
});
if 0 == *epochs_left {
self.training = TrainingState::NoTrain;
}
}
}
})
});
ui.label("");
ui.horizontal(|ui| {
for (i, s) in self.symbols.iter().enumerate() {
if ui.button(s.to_owned()).clicked() {
self.model.set_input_from_etalon(i);
}
}
});
ui.label("");
for y in 0..7 {
ui.horizontal(|ui| {
for x in 0..7 {
ui.checkbox(&mut self.model.input[y][x], "");
}
});
}
ui.label("");
ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| {
if ui.button("Check").clicked() {
self.compute_result = self.model.calculate_predictions();
}
});
TableBuilder::new(ui)
.striped(true) // Alternating row colors
.resizable(true)
.column(Column::remainder())
.column(Column::remainder())
.header(20.0, |mut header| {
header.col(|ui| {
ui.label("Class");
});
header.col(|ui| {
ui.label("Probability");
});
})
.body(|body| {
body.rows(20.0, self.symbols.len(), |mut row| {
let i = row.index();
row.col(|ui| {
ui.label(self.symbols[i]);
});
row.col(|ui| {
ui.label(self.compute_result[i].to_string());
});
});
});
});
}
}