@marcellejs/onnx: Components
onnxModel
```ts
function onnxModel({
  inputType: 'image' | 'generic';
  taskType: 'classification' | 'generic';
  segmentationOptions?: {
    output?: 'image' | 'tensor';
    inputShape: number[];
  };
}): OnnxModel;
```

This component allows making predictions using pre-trained models in the ONNX format, using onnxruntime-web. The default backend for inference is wasm, as it provides wider operator support.
The implementation currently supports tensors as input, formatted as nested number arrays, and two types of tasks (classification and generic prediction). Pre-trained models can be loaded either from a URL or through file upload, for instance using the fileUpload component.
Such generic models cannot be trained.
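As an illustration, the following is a minimal sketch of loading a model file through the fileUpload component mentioned above. The `$files` stream name and the generic input/task types are assumptions to adapt to your setup:

```js
// Sketch: loading an ONNX model through file upload.
// Assumes the fileUpload component exposes a `$files` stream of
// selected File objects (check the fileUpload documentation).
const upload = fileUpload();
const model = onnxModel({
  inputType: 'generic',
  taskType: 'generic',
});

upload.$files.subscribe(async (files) => {
  await model.loadFromFile(files[0]); // load the selected *.onnx file
});
```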
WARNING
onnxruntime-web is not included in the build. To use the onnxModel component, add the following line to your index.html:

```html
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@1.19.2/dist/ort.wasm.min.js"></script>
```

Methods
.loadFromFile()
```ts
loadFromFile(file: File): Promise<void>
```

Load a pre-trained ONNX model from a *.onnx file.
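For example, a File obtained from a native file input can be passed directly to this method. This is a sketch only; the input element and the `classifier` instance are assumed to exist:

```js
// Sketch: load a model from a native <input type="file"> element.
// Assumes `classifier` is an OnnxModel instance created with onnxModel().
const fileInput = document.querySelector('input[type="file"]');
fileInput.addEventListener('change', async () => {
  const [file] = fileInput.files;
  if (file) {
    await classifier.loadFromFile(file); // the model is ready once this resolves
  }
});
```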
.loadFromUrl()
```ts
loadFromUrl(url: string): Promise<void>
```

Load a pre-trained ONNX model from a URL.
.predict()
```ts
predict(input: InputTypes[InputType]): Promise<PredictionTypes[TaskType]>
```

Make a prediction from an input instance, whose type depends on the inputType specified in the constructor. The method is asynchronous and returns a promise that resolves with the results of the prediction (a usage sketch is given after the result interface below).
Input types can be:
- `ImageData` if the model was instantiated with `inputType: 'image'`
- `TensorLike` (= array) if the model was instantiated with `inputType: 'generic'`
Output types can be:
- `ClassifierResults` if the model was instantiated with `taskType: 'classification'`
- `TensorLike` if the model was instantiated with `taskType: 'generic'`
Where classifier results have the following interface:
```ts
interface ClassifierResults {
  label: string;
  confidences: { [key: string]: number };
}
```
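As a rough sketch, a single classification prediction and its results could look like this. The `classifier` instance and the `imageData` input are assumed to exist:

```js
// Sketch: making a single prediction and reading the results.
// Assumes `classifier` was created with inputType: 'image' and
// taskType: 'classification', and that a model has already been loaded.
// `imageData` is an ImageData instance, e.g. obtained from a canvas.
const results = await classifier.predict(imageData);
console.log(results.label); // predicted class label
console.log(results.confidences); // per-class confidence scores
```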
Example

```js
const source = imageUpload();
const classifier = onnxModel({
  inputType: 'image',
  taskType: 'classification',
});
// Load a pre-trained model (placeholder URL).
classifier.loadFromUrl('https://example.com/my-model.onnx');

const predictionStream = source.$images
  .map((img) => classifier.predict(img))
  .awaitPromises();
```