From ef711b5e955bbf9ff36d28bc37b66fc7b395cb82 Mon Sep 17 00:00:00 2001 From: Andrew Golovashevich Date: Sun, 8 Feb 2026 21:40:19 +0300 Subject: [PATCH] [lab4] Model implementation --- lab4/src/algo/compute.rs | 8 ++--- lab4/src/algo/fix.rs | 22 ++++++------ lab4/src/algo/mod.rs | 1 + lab4/src/algo/net.rs | 75 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 15 deletions(-) create mode 100644 lab4/src/algo/net.rs diff --git a/lab4/src/algo/compute.rs b/lab4/src/algo/compute.rs index 9c9fc77..1f92439 100644 --- a/lab4/src/algo/compute.rs +++ b/lab4/src/algo/compute.rs @@ -1,8 +1,8 @@ use std::ops::Index; -fn compute_potential>( - weights: &[WA], +pub(super) fn compute>( input_data: &[f64], + weights: &[WA], potential_data: &mut [f64], output_data: &mut [f64], f: impl Fn(f64) -> f64, @@ -28,7 +28,7 @@ fn test_compiles_static() { let mut potential_data = [0.0, 0.0]; let mut output_data = [0.0, 0.0]; - compute_potential(&weights, &input_data, &mut potential_data, &mut output_data, |x| x); + compute(&input_data,&weights, &mut potential_data, &mut output_data, |x| x); } #[test] @@ -38,5 +38,5 @@ fn test_compiles_dynamic() { let mut potential_data = [0.0, 0.0]; let mut output_data = [0.0, 0.0]; - compute_potential(&weights, &input_data, &mut potential_data, &mut output_data, |x| x); + compute(&input_data, &weights, &mut potential_data, &mut output_data, |x| x); } \ No newline at end of file diff --git a/lab4/src/algo/fix.rs b/lab4/src/algo/fix.rs index 0c28921..948f09b 100644 --- a/lab4/src/algo/fix.rs +++ b/lab4/src/algo/fix.rs @@ -1,12 +1,12 @@ use std::ops::{Index, IndexMut}; -fn calc_error>( +pub(super) fn calc_error>( next_errors: &[f64], - weights: &[NWA], + next_weights: &[NWA], current_errors: &mut [f64], ) { for i in 0..current_errors.len() { - current_errors[i] = weights + current_errors[i] = next_weights .iter() .enumerate() .map(|(j, ww)| ww[i] * next_errors[j]) @@ -14,19 +14,19 @@ fn calc_error>( } } -fn apply_error>( +pub(super) fn apply_error>( n: f64, - errors: &[f64], - weights: &mut [NWA], - current_potentials: &[f64], + next_errors: &[f64], + next_weights: &mut [NWA], + current_outputs: &[f64], next_potentials: &[f64], - f: impl Fn(f64) -> f64, f1: impl Fn(f64) -> f64, ) { - for i in 0..current_potentials.len() { + for i in 0..current_outputs.len() { for j in 0..next_potentials.len() { - let dw = n * errors[j] * f1(next_potentials[j]) * f(current_potentials[i]); - weights[j][i] += dw; + let dw = n * next_errors[j] * f1(next_potentials[j]) * current_outputs[i]; + next_weights[j][i] += dw; } } } + diff --git a/lab4/src/algo/mod.rs b/lab4/src/algo/mod.rs index 0a290fc..a89eddd 100644 --- a/lab4/src/algo/mod.rs +++ b/lab4/src/algo/mod.rs @@ -1,2 +1,3 @@ mod compute; mod fix; +mod net; diff --git a/lab4/src/algo/net.rs b/lab4/src/algo/net.rs new file mode 100644 index 0000000..9b1ec8e --- /dev/null +++ b/lab4/src/algo/net.rs @@ -0,0 +1,75 @@ +use crate::algo::compute::compute; +use crate::algo::fix::{apply_error, calc_error}; +use std::ops::{Div, Mul}; + +struct Net { + inner_layer_size: usize, + pub input_data: [f64; InputLayerSize], + i_to_h_weights: Vec<[f64; InputLayerSize]>, + hidden_potentials: Vec, + hidden_data: Vec, + h_to_o_weights: [Vec; OutputLayerSize], + output_potentials: [f64; OutputLayerSize], + pub output_data: [f64; OutputLayerSize], + output_errors: [f64; OutputLayerSize], + hidden_errors: Vec, +} + +impl + Net +{ + fn _sigmoid(x: f64) -> f64 { + return 1.0.div(1.0 + (-x).exp()); + } + + fn _sigmoidDerivative(x: f64) -> f64 { + return x.mul(1.0 - x); + } + + pub fn compute(&mut self) { + compute( + &self.input_data, + self.i_to_h_weights.as_slice(), + self.hidden_potentials.as_mut_slice(), + self.hidden_data.as_mut_slice(), + Self::_sigmoid, + ); + + compute( + self.hidden_data.as_slice(), + &self.h_to_o_weights, + &mut self.output_potentials, + &mut self.output_data, + Self::_sigmoid, + ) + } + + pub fn fix(&mut self, expected: &[f64; OutputLayerSize], n: f64) { + for (i, (a, e)) in self.output_data.iter().zip(expected).enumerate() { + self.output_errors[i] = e - a; + } + + calc_error( + &self.output_errors, + &self.h_to_o_weights, + self.hidden_errors.as_mut_slice(), + ); + + apply_error( + n, + &self.output_errors, + &mut self.h_to_o_weights, + self.hidden_data.as_slice(), + &self.output_potentials, + Self::_sigmoidDerivative + ); + apply_error( + n, + self.hidden_errors.as_slice(), + self.i_to_h_weights.as_mut_slice(), + &self.input_data, + self.hidden_potentials.as_slice(), + Self::_sigmoidDerivative + ); + } +}