refact: added detect to predict to class

also defined ObjectDetector as a UseCase
This commit is contained in:
Joshua Shoemaker 2021-01-13 16:51:11 -06:00
parent 685c57d3a9
commit d201c36861
3 changed files with 20 additions and 24 deletions

View File

@ -1,5 +1,6 @@
import * as tf from '@tensorflow/tfjs' import * as tf from '@tensorflow/tfjs'
import * as cocossd from '@tensorflow-models/coco-ssd' import * as cocossd from '@tensorflow-models/coco-ssd'
import PredictedObject from '../Models/PredictedObject'
let instance: ObjectDetector | null = null let instance: ObjectDetector | null = null
@ -16,7 +17,19 @@ class ObjectDetector {
return instance return instance
} }
private doesPredictionPassFilterPredicates (prediction: cocossd.DetectedObject): boolean { private convertDetectedToPredictedObjects = (detectedObjects: cocossd.DetectedObject[]) => {
const predictedObjects: PredictedObject[] = detectedObjects.map(p => new PredictedObject({
xOrigin: p.bbox[0],
yOrigin: p.bbox[1],
width: p.bbox[2],
height: p.bbox[3],
class: p.class
}))
return predictedObjects
}
private doesDetectionPassFilterPredicates (prediction: cocossd.DetectedObject): boolean {
let failedPredictions = [] let failedPredictions = []
this.filterPredicates.forEach(filter => { this.filterPredicates.forEach(filter => {
if (!filter(prediction)) failedPredictions.push(filter) if (!filter(prediction)) failedPredictions.push(filter)
@ -28,10 +41,11 @@ class ObjectDetector {
public predictImageStream = async (videoImage: ImageData) => { public predictImageStream = async (videoImage: ImageData) => {
const mlModel = await this.loadMlModel() const mlModel = await this.loadMlModel()
const predictions = await mlModel.detect(videoImage) const detectedObjects = await mlModel.detect(videoImage)
const filteredPredictions = predictions.filter(p => this.doesPredictionPassFilterPredicates(p)) const filteredDetections = detectedObjects.filter(p => this.doesDetectionPassFilterPredicates(p))
const predictions = this.convertDetectedToPredictedObjects(filteredDetections)
return filteredPredictions return predictions
} }
public async loadMlModel (): Promise<cocossd.ObjectDetection> { public async loadMlModel (): Promise<cocossd.ObjectDetection> {

View File

@ -1,16 +0,0 @@
import { DetectedObject } from "@tensorflow-models/coco-ssd"
import PredictedObject from '../Models/PredictedObject'
const convertDetectedToPredictedObjects = (detectedObjects: DetectedObject[]) => {
const predictedObjects: PredictedObject[] = detectedObjects.map(p => new PredictedObject({
xOrigin: p.bbox[0],
yOrigin: p.bbox[1],
width: p.bbox[2],
height: p.bbox[3],
class: p.class
}))
return predictedObjects
}
export default convertDetectedToPredictedObjects

View File

@ -1,8 +1,7 @@
import { DetectedObject } from "@tensorflow-models/coco-ssd" import { DetectedObject } from "@tensorflow-models/coco-ssd"
import PredictedObjectCollectionController from "./Controllers/PredictedObjectCollectionController" import PredictedObjectCollectionController from "./Controllers/PredictedObjectCollectionController"
import VideoController from './Controllers/VideoController' import VideoController from './Controllers/VideoController'
import ObjectDetector from './Models/ObjectDetector' import ObjectDetector from './UseCases/ObjectDetector'
import convertDetectedtoPredictedObject from './UseCases/convertDetectedToPredictedObjects'
const defaultPredictions = [ const defaultPredictions = [
(prediction: DetectedObject) => prediction.score > 0.6, (prediction: DetectedObject) => prediction.score > 0.6,
@ -29,8 +28,7 @@ class App {
return return
} }
const detectedObjects: DetectedObject[] = await this.objectDetector.predictImageStream(imageData) const predictedObjects = await this.objectDetector.predictImageStream(imageData)
const predictedObjects = convertDetectedtoPredictedObject(detectedObjects)
this.predictedObjectCollectionController.predictedObjects = predictedObjects this.predictedObjectCollectionController.predictedObjects = predictedObjects
window.requestAnimationFrame(this.predictImage) window.requestAnimationFrame(this.predictImage)