Skip to content

Commit

Permalink
Add training loop
Browse files Browse the repository at this point in the history
  • Loading branch information
mam10eks committed Feb 14, 2024
1 parent 54b29d0 commit 608a175
Show file tree
Hide file tree
Showing 8 changed files with 783 additions and 21 deletions.
10 changes: 10 additions & 0 deletions .docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#docker build -t mam10eks/tira-for-school:0.0.1 -f .docker/Dockerfile .
FROM mam10eks/github-page-tutorial:0.0.2

ADD pupil-submission-page/package-lock.json pupil-submission-page/yarn.lock pupil-submission-page/package.json /tmp/

RUN cd /tmp &&\
cp -r /usr/local/lib/node_modules . \
&& npm install \
&& rm -Rf /usr/local/lib/node_modules \
&& mv node_modules /usr/local/lib/node_modules
2 changes: 1 addition & 1 deletion pupil-submission-page/.devcontainer.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"image": "mam10eks/github-page-tutorial:0.0.2",
"image": "mam10eks/tira-for-school:0.0.1",
"customizations": {
"vscode": {
"extensions": ["ms-python.python", "ms-python.vscode-pylance", "ms-toolsai.jupyter"]
Expand Down
409 changes: 397 additions & 12 deletions pupil-submission-page/package-lock.json

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pupil-submission-page/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
},
"dependencies": {
"@mdi/font": "7.0.96",
"@tensorflow-models/knn-classifier": "^1.2.6",
"@tensorflow-models/mobilenet": "^2.1.1",
"@tensorflow/tfjs": "^4.17.0",
"core-js": "^3.29.0",
"roboto-fontface": "*",
"vue": "^3.2.0",
Expand Down
2 changes: 1 addition & 1 deletion pupil-submission-page/src/components/MainStudentPage.vue
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

<template v-slot:item.2 :rules="[() => false]" error="tmp"><Stepper_2 :klasse_vorfahrt_strasse="klasse_vorfahrt_strasse" :klasse_vorfahrt_gewaehren="klasse_vorfahrt_gewaehren"/></template>

<template v-slot:item.3><Stepper_3 @model-trained="modelTrained"/></template>
<template v-slot:item.3><Stepper_3 @model-trained="modelTrained" :klasse_vorfahrt_strasse="klasse_vorfahrt_strasse" :klasse_vorfahrt_gewaehren="klasse_vorfahrt_gewaehren"/></template>

<template v-slot:item.4><Stepper_4/></template>

Expand Down
30 changes: 28 additions & 2 deletions pupil-submission-page/src/components/Stepper_3.vue
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@
</template>

<script lang="ts">
import {load_model} from '@/training.ts'
function Sleep(milliseconds) {
return new Promise(resolve => setTimeout(resolve, milliseconds));
}
export default {
props: ['klasse_vorfahrt_strasse' , 'klasse_vorfahrt_gewaehren'],
data: () => ({
cols: 2,
epochs: 50,
Expand All @@ -51,12 +53,36 @@ export default {
}),
methods: {
async train() {
let model = await load_model()
for (let i=1; i< 100; i++) {
this.training_progress = i
await Sleep(75)
for (let class_0 of this.klasse_vorfahrt_strasse) {
model.train(class_0.src, 0)
}
for (let class_1 of this.klasse_vorfahrt_gewaehren) {
model.train(class_1.src, 1)
}
await Sleep(1)
}
for (let class_0 of this.klasse_vorfahrt_strasse) {
let prediction = await model.predict(class_0.src)
console.log('Class 0: ' + prediction.classIndex)
console.log(prediction.confidences)
}
for (let class_1 of this.klasse_vorfahrt_gewaehren) {
let prediction = await model.predict(class_1.src)
console.log('Class 1: ' + prediction.classIndex)
console.log(prediction.confidences)
}
this.training_progress = 0
this.$emit('model-trained', {'tmp': '1'})
this.$emit('model-trained', {'tmp': '1', 'model': model})
}
}
}
Expand Down
58 changes: 58 additions & 0 deletions pupil-submission-page/src/training.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import * as mobilenetModule from '@tensorflow-models/mobilenet'
import * as tf from '@tensorflow/tfjs'
import * as knnClassifier from '@tensorflow-models/knn-classifier';

let mobileNetPromise = mobilenetModule.load()
let mobileNet: any = null;


export async function load_model() {
if (mobileNet == null) {
console.log('ToDo: Load Mobile net...')
mobileNet = await mobileNetPromise
}

class Main {
mobileNet: any
knn: any

constructor() {
this.mobileNet = mobileNet
this.knn = knnClassifier.create();
}

infer_with_mobilenet(img_src: string) {
const img = new Image()
img.src = img_src;
img.width = 227
img.height = 227

return this.mobileNet.infer(tf.browser.fromPixels(img), 'conv_preds')
}

train(image: string, class_id: 0|1) {
const infer = this.infer_with_mobilenet(image)
this.knn.addExample(infer, class_id)
}

async predict(image: string) {
const logits = this.infer_with_mobilenet(image)
const ret = await this.knn.predictClass(logits, 10)
return ret
}
}

return new Main()
}

mobileNetPromise.then((mobilenet:any) => {
mobileNet = mobilenet
console.log('mobilenet loaded' + mobilenet)
})

export async function predict(image: HTMLImageElement) {
const img = tf.browser.fromPixels(image)
const infer = mobileNet.infer(img, 'conv_preds')
const result = await infer.data()
return result
}
Loading

0 comments on commit 608a175

Please sign in to comment.