[lab4] Model implementation
This commit is contained in:
parent
5c8e1f9dc2
commit
ef711b5e95
@ -1,8 +1,8 @@
|
|||||||
use std::ops::Index;
|
use std::ops::Index;
|
||||||
|
|
||||||
fn compute_potential<WA: Index<usize, Output = f64>>(
|
pub(super) fn compute<WA: Index<usize, Output = f64>>(
|
||||||
weights: &[WA],
|
|
||||||
input_data: &[f64],
|
input_data: &[f64],
|
||||||
|
weights: &[WA],
|
||||||
potential_data: &mut [f64],
|
potential_data: &mut [f64],
|
||||||
output_data: &mut [f64],
|
output_data: &mut [f64],
|
||||||
f: impl Fn(f64) -> f64,
|
f: impl Fn(f64) -> f64,
|
||||||
@ -28,7 +28,7 @@ fn test_compiles_static() {
|
|||||||
let mut potential_data = [0.0, 0.0];
|
let mut potential_data = [0.0, 0.0];
|
||||||
let mut output_data = [0.0, 0.0];
|
let mut output_data = [0.0, 0.0];
|
||||||
|
|
||||||
compute_potential(&weights, &input_data, &mut potential_data, &mut output_data, |x| x);
|
compute(&input_data,&weights, &mut potential_data, &mut output_data, |x| x);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -38,5 +38,5 @@ fn test_compiles_dynamic() {
|
|||||||
let mut potential_data = [0.0, 0.0];
|
let mut potential_data = [0.0, 0.0];
|
||||||
let mut output_data = [0.0, 0.0];
|
let mut output_data = [0.0, 0.0];
|
||||||
|
|
||||||
compute_potential(&weights, &input_data, &mut potential_data, &mut output_data, |x| x);
|
compute(&input_data, &weights, &mut potential_data, &mut output_data, |x| x);
|
||||||
}
|
}
|
||||||
@ -1,12 +1,12 @@
|
|||||||
use std::ops::{Index, IndexMut};
|
use std::ops::{Index, IndexMut};
|
||||||
|
|
||||||
fn calc_error<NWA: Index<usize, Output = f64>>(
|
pub(super) fn calc_error<NWA: Index<usize, Output = f64>>(
|
||||||
next_errors: &[f64],
|
next_errors: &[f64],
|
||||||
weights: &[NWA],
|
next_weights: &[NWA],
|
||||||
current_errors: &mut [f64],
|
current_errors: &mut [f64],
|
||||||
) {
|
) {
|
||||||
for i in 0..current_errors.len() {
|
for i in 0..current_errors.len() {
|
||||||
current_errors[i] = weights
|
current_errors[i] = next_weights
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(j, ww)| ww[i] * next_errors[j])
|
.map(|(j, ww)| ww[i] * next_errors[j])
|
||||||
@ -14,19 +14,19 @@ fn calc_error<NWA: Index<usize, Output = f64>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_error<NWA: IndexMut<usize, Output = f64>>(
|
pub(super) fn apply_error<NWA: IndexMut<usize, Output = f64>>(
|
||||||
n: f64,
|
n: f64,
|
||||||
errors: &[f64],
|
next_errors: &[f64],
|
||||||
weights: &mut [NWA],
|
next_weights: &mut [NWA],
|
||||||
current_potentials: &[f64],
|
current_outputs: &[f64],
|
||||||
next_potentials: &[f64],
|
next_potentials: &[f64],
|
||||||
f: impl Fn(f64) -> f64,
|
|
||||||
f1: impl Fn(f64) -> f64,
|
f1: impl Fn(f64) -> f64,
|
||||||
) {
|
) {
|
||||||
for i in 0..current_potentials.len() {
|
for i in 0..current_outputs.len() {
|
||||||
for j in 0..next_potentials.len() {
|
for j in 0..next_potentials.len() {
|
||||||
let dw = n * errors[j] * f1(next_potentials[j]) * f(current_potentials[i]);
|
let dw = n * next_errors[j] * f1(next_potentials[j]) * current_outputs[i];
|
||||||
weights[j][i] += dw;
|
next_weights[j][i] += dw;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,2 +1,3 @@
|
|||||||
mod compute;
|
mod compute;
|
||||||
mod fix;
|
mod fix;
|
||||||
|
mod net;
|
||||||
|
|||||||
75
lab4/src/algo/net.rs
Normal file
75
lab4/src/algo/net.rs
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
use crate::algo::compute::compute;
|
||||||
|
use crate::algo::fix::{apply_error, calc_error};
|
||||||
|
use std::ops::{Div, Mul};
|
||||||
|
|
||||||
|
struct Net<const InputLayerSize: usize, const OutputLayerSize: usize> {
|
||||||
|
inner_layer_size: usize,
|
||||||
|
pub input_data: [f64; InputLayerSize],
|
||||||
|
i_to_h_weights: Vec<[f64; InputLayerSize]>,
|
||||||
|
hidden_potentials: Vec<f64>,
|
||||||
|
hidden_data: Vec<f64>,
|
||||||
|
h_to_o_weights: [Vec<f64>; OutputLayerSize],
|
||||||
|
output_potentials: [f64; OutputLayerSize],
|
||||||
|
pub output_data: [f64; OutputLayerSize],
|
||||||
|
output_errors: [f64; OutputLayerSize],
|
||||||
|
hidden_errors: Vec<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const InputLayerSize: usize, const OutputLayerSize: usize>
|
||||||
|
Net<InputLayerSize, OutputLayerSize>
|
||||||
|
{
|
||||||
|
fn _sigmoid(x: f64) -> f64 {
|
||||||
|
return 1.0.div(1.0 + (-x).exp());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _sigmoidDerivative(x: f64) -> f64 {
|
||||||
|
return x.mul(1.0 - x);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn compute(&mut self) {
|
||||||
|
compute(
|
||||||
|
&self.input_data,
|
||||||
|
self.i_to_h_weights.as_slice(),
|
||||||
|
self.hidden_potentials.as_mut_slice(),
|
||||||
|
self.hidden_data.as_mut_slice(),
|
||||||
|
Self::_sigmoid,
|
||||||
|
);
|
||||||
|
|
||||||
|
compute(
|
||||||
|
self.hidden_data.as_slice(),
|
||||||
|
&self.h_to_o_weights,
|
||||||
|
&mut self.output_potentials,
|
||||||
|
&mut self.output_data,
|
||||||
|
Self::_sigmoid,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fix(&mut self, expected: &[f64; OutputLayerSize], n: f64) {
|
||||||
|
for (i, (a, e)) in self.output_data.iter().zip(expected).enumerate() {
|
||||||
|
self.output_errors[i] = e - a;
|
||||||
|
}
|
||||||
|
|
||||||
|
calc_error(
|
||||||
|
&self.output_errors,
|
||||||
|
&self.h_to_o_weights,
|
||||||
|
self.hidden_errors.as_mut_slice(),
|
||||||
|
);
|
||||||
|
|
||||||
|
apply_error(
|
||||||
|
n,
|
||||||
|
&self.output_errors,
|
||||||
|
&mut self.h_to_o_weights,
|
||||||
|
self.hidden_data.as_slice(),
|
||||||
|
&self.output_potentials,
|
||||||
|
Self::_sigmoidDerivative
|
||||||
|
);
|
||||||
|
apply_error(
|
||||||
|
n,
|
||||||
|
self.hidden_errors.as_slice(),
|
||||||
|
self.i_to_h_weights.as_mut_slice(),
|
||||||
|
&self.input_data,
|
||||||
|
self.hidden_potentials.as_slice(),
|
||||||
|
Self::_sigmoidDerivative
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user