[lab4] Model wrapper to work with domain
This commit is contained in:
parent
ef711b5e95
commit
7071338f96
104
lab4/src/algo/comparation_operators_model.rs
Normal file
104
lab4/src/algo/comparation_operators_model.rs
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
use crate::algo::net::Net;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
struct ComparationOperatorsModel<const W: usize, const H: usize, const ClassCount: usize>
|
||||||
|
where
|
||||||
|
[(); { W * H }]:,
|
||||||
|
{
|
||||||
|
net: Net<{ W * H }, ClassCount>,
|
||||||
|
pub input: [[bool; W]; H],
|
||||||
|
etalons: [[[bool; W]; H]; ClassCount],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const W: usize, const H: usize, const ClassCount: usize>
|
||||||
|
ComparationOperatorsModel<W, H, ClassCount>
|
||||||
|
where
|
||||||
|
[(); { W * H }]:,
|
||||||
|
{
|
||||||
|
pub fn new(etalons: [[[bool; W]; H]; ClassCount]) -> Self {
|
||||||
|
return Self {
|
||||||
|
net: Net::new(),
|
||||||
|
input: [[false; W]; H],
|
||||||
|
etalons,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_inputs(net: &mut Net<{ W * H }, ClassCount>, data: &[[bool; W]; H]) {
|
||||||
|
let mut i = 0;
|
||||||
|
for y in 0..H {
|
||||||
|
for x in 0..W {
|
||||||
|
net.input_data[i] = if data[y][x] { 1.0 } else { 0.0 };
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn train_one(&mut self, n: f64, image: usize) {
|
||||||
|
assert!(image < self.etalons.len());
|
||||||
|
|
||||||
|
Self::set_inputs(&mut self.net, &self.etalons[image]);
|
||||||
|
|
||||||
|
self.net.compute();
|
||||||
|
|
||||||
|
let mut expected = [0.0; ClassCount];
|
||||||
|
expected[image] = 1.0;
|
||||||
|
|
||||||
|
self.net.fix(&expected, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train_epoch(&mut self, n: f64) {
|
||||||
|
for i in 0..ClassCount {
|
||||||
|
self.train_one(n, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train(&mut self, n: f64, epochs: usize) -> impl Iterator<Item = ()> {
|
||||||
|
return TrainIterator {
|
||||||
|
owner: self,
|
||||||
|
n: n,
|
||||||
|
epochs_left: epochs,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
pub fn hidden_layer_size(&self) -> usize {
|
||||||
|
return self.net.hidden_layer_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn resize_hidden_layer(&mut self, new_size: usize) {
|
||||||
|
self.net.resize_hidden_layer(new_size);
|
||||||
|
}
|
||||||
|
pub fn set_random_weights(&mut self, rng: &mut impl Rng) {
|
||||||
|
self.net.set_random_weights(rng)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn calculate_predictions(&mut self) -> [f64; ClassCount] {
|
||||||
|
Self::set_inputs(&mut self.net, &self.input);
|
||||||
|
self.net.compute();
|
||||||
|
return self.net.output_data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TrainIterator<'a, const W: usize, const H: usize, const ClassCount: usize>
|
||||||
|
where
|
||||||
|
[(); { W * H }]:,
|
||||||
|
{
|
||||||
|
owner: &'a mut ComparationOperatorsModel<W, H, ClassCount>,
|
||||||
|
n: f64,
|
||||||
|
epochs_left: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, const W: usize, const H: usize, const ClassCount: usize> Iterator
|
||||||
|
for TrainIterator<'_, W, H, ClassCount>
|
||||||
|
where
|
||||||
|
[(); { W * H }]:,
|
||||||
|
{
|
||||||
|
type Item = ();
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.epochs_left == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.owner.train_epoch(self.n);
|
||||||
|
return Some(());
|
||||||
|
}
|
||||||
|
}
|
||||||
114
lab4/src/algo/images.rs
Normal file
114
lab4/src/algo/images.rs
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
use std::array;
|
||||||
|
|
||||||
|
fn image<const H: usize, const W: usize>(raw: [&str; H]) -> [[bool; W]; H] {
|
||||||
|
return raw.map(|r| {
|
||||||
|
let mut it = r.chars().map(|c| match c {
|
||||||
|
'+' => true,
|
||||||
|
'0' => false,
|
||||||
|
_ => panic!(),
|
||||||
|
});
|
||||||
|
|
||||||
|
return array::from_fn(|_| it.next().unwrap())
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gen_images() -> [[[bool; 7]; 7]; 8] {
|
||||||
|
return [
|
||||||
|
// <
|
||||||
|
image(
|
||||||
|
[
|
||||||
|
" ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" ",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
// <=
|
||||||
|
image(
|
||||||
|
[
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
// >
|
||||||
|
image(
|
||||||
|
[
|
||||||
|
" ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" ",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
// >=
|
||||||
|
image(
|
||||||
|
[
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
" * * ",
|
||||||
|
" * ",
|
||||||
|
" * ",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
// =
|
||||||
|
image(
|
||||||
|
[
|
||||||
|
" ",
|
||||||
|
" ",
|
||||||
|
" ***** ",
|
||||||
|
" ",
|
||||||
|
" ***** ",
|
||||||
|
" ",
|
||||||
|
" ",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
// !=
|
||||||
|
image(
|
||||||
|
[
|
||||||
|
" ",
|
||||||
|
" * ",
|
||||||
|
" ***** ",
|
||||||
|
" * ",
|
||||||
|
" ***** ",
|
||||||
|
" * ",
|
||||||
|
" ",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
// ≡
|
||||||
|
image(
|
||||||
|
[
|
||||||
|
" ",
|
||||||
|
" ***** ",
|
||||||
|
" ",
|
||||||
|
" ***** ",
|
||||||
|
" ",
|
||||||
|
" ***** ",
|
||||||
|
" ",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
// ≈
|
||||||
|
image(
|
||||||
|
[
|
||||||
|
" * ",
|
||||||
|
" * * * ",
|
||||||
|
" * ",
|
||||||
|
" ",
|
||||||
|
" * ",
|
||||||
|
" * * * ",
|
||||||
|
" * ",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
}
|
||||||
@ -1,3 +1,5 @@
|
|||||||
mod compute;
|
mod compute;
|
||||||
mod fix;
|
mod fix;
|
||||||
mod net;
|
mod net;
|
||||||
|
mod comparation_operators_model;
|
||||||
|
mod images;
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
|
use std::array;
|
||||||
use crate::algo::compute::compute;
|
use crate::algo::compute::compute;
|
||||||
use crate::algo::fix::{apply_error, calc_error};
|
use crate::algo::fix::{apply_error, calc_error};
|
||||||
|
use rand::Rng;
|
||||||
use std::ops::{Div, Mul};
|
use std::ops::{Div, Mul};
|
||||||
|
|
||||||
struct Net<const InputLayerSize: usize, const OutputLayerSize: usize> {
|
pub(super) struct Net<const InputLayerSize: usize, const OutputLayerSize: usize> {
|
||||||
inner_layer_size: usize,
|
inner_layer_size: usize,
|
||||||
pub input_data: [f64; InputLayerSize],
|
pub input_data: [f64; InputLayerSize],
|
||||||
i_to_h_weights: Vec<[f64; InputLayerSize]>,
|
i_to_h_weights: Vec<[f64; InputLayerSize]>,
|
||||||
@ -18,6 +20,22 @@ struct Net<const InputLayerSize: usize, const OutputLayerSize: usize> {
|
|||||||
impl<const InputLayerSize: usize, const OutputLayerSize: usize>
|
impl<const InputLayerSize: usize, const OutputLayerSize: usize>
|
||||||
Net<InputLayerSize, OutputLayerSize>
|
Net<InputLayerSize, OutputLayerSize>
|
||||||
{
|
{
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
return Self {
|
||||||
|
inner_layer_size: 1,
|
||||||
|
input_data: [0.0; InputLayerSize],
|
||||||
|
i_to_h_weights: Vec::from([[0.0; InputLayerSize]]),
|
||||||
|
hidden_potentials: Vec::from([0.0]),
|
||||||
|
hidden_data: Vec::from([0.0]),
|
||||||
|
h_to_o_weights: array::from_fn(|_| Vec::from([0.0])),
|
||||||
|
output_potentials: [0.0; OutputLayerSize],
|
||||||
|
output_data: [0.0; OutputLayerSize],
|
||||||
|
output_errors: [0.0; OutputLayerSize],
|
||||||
|
hidden_errors: Vec::from([0.0]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
fn _sigmoid(x: f64) -> f64 {
|
fn _sigmoid(x: f64) -> f64 {
|
||||||
return 1.0.div(1.0 + (-x).exp());
|
return 1.0.div(1.0 + (-x).exp());
|
||||||
}
|
}
|
||||||
@ -61,7 +79,7 @@ impl<const InputLayerSize: usize, const OutputLayerSize: usize>
|
|||||||
&mut self.h_to_o_weights,
|
&mut self.h_to_o_weights,
|
||||||
self.hidden_data.as_slice(),
|
self.hidden_data.as_slice(),
|
||||||
&self.output_potentials,
|
&self.output_potentials,
|
||||||
Self::_sigmoidDerivative
|
Self::_sigmoidDerivative,
|
||||||
);
|
);
|
||||||
apply_error(
|
apply_error(
|
||||||
n,
|
n,
|
||||||
@ -69,7 +87,36 @@ impl<const InputLayerSize: usize, const OutputLayerSize: usize>
|
|||||||
self.i_to_h_weights.as_mut_slice(),
|
self.i_to_h_weights.as_mut_slice(),
|
||||||
&self.input_data,
|
&self.input_data,
|
||||||
self.hidden_potentials.as_slice(),
|
self.hidden_potentials.as_slice(),
|
||||||
Self::_sigmoidDerivative
|
Self::_sigmoidDerivative,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn hidden_layer_size(&self) -> usize {
|
||||||
|
return self.inner_layer_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn resize_hidden_layer(&mut self, new_size: usize) {
|
||||||
|
assert!(new_size > 0);
|
||||||
|
self.inner_layer_size = new_size;
|
||||||
|
self.i_to_h_weights.resize(new_size, [0.0; InputLayerSize]);
|
||||||
|
self.hidden_potentials.resize(new_size, 0.0);
|
||||||
|
self.hidden_data.resize(new_size, 0.0);
|
||||||
|
for w in self.h_to_o_weights.iter_mut() {
|
||||||
|
w.resize(new_size, 0.0);
|
||||||
|
}
|
||||||
|
self.hidden_errors.resize(new_size, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_random_weights(&mut self, rng: &mut impl Rng) {
|
||||||
|
for ww in self.i_to_h_weights.iter_mut() {
|
||||||
|
for w in ww.iter_mut() {
|
||||||
|
*w = rng.random()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for ww in self.h_to_o_weights.iter_mut() {
|
||||||
|
for w in ww.iter_mut() {
|
||||||
|
*w = rng.random()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user