|
| 1 | +# Using a custom pytorchjs transform in your workflow |
| 2 | + |
| 3 | +## Background |
| 4 | +This example assumes that you have a pre-trained torch model, that has been exported to [torchscript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html), as well as an image dataset conforming to the structure that the pytorchjs DatasetFolder class expects. The structure may also be found below, for reference, and a sample script to export a torch model to torchscript can be found [here](../Exporting/Exporting.md) |
| 5 | + |
| 6 | +## Required File Structure |
| 7 | +``` |
| 8 | +root/class_x/xxx.ext |
| 9 | +root/class_x/xxy.ext |
| 10 | +root/class_x/xxz.ext |
| 11 | +
|
| 12 | +root/class_y/123.ext |
| 13 | +root/class_y/nsdf3.ext |
| 14 | +root/class_y/asd932_.ext |
| 15 | +``` |
| 16 | + |
| 17 | +## Use Case |
| 18 | +To add a custom transform function to your `transforms.Compose` call |
| 19 | + |
| 20 | +## Sample Script |
| 21 | +### Import Dependancies |
| 22 | +```js |
| 23 | +import { torch, torchvision } from 'pytorchjs'; |
| 24 | + |
| 25 | +const { load } = torch; |
| 26 | +const { DataLoader } = torch.utils.data; |
| 27 | +const { ImageFolder } = torchvision.datasets; |
| 28 | +const { Compose, Resize, InvertAxes, Normalize } = torchvision.transforms; |
| 29 | +``` |
| 30 | + |
| 31 | +### Load Model |
| 32 | +```js |
| 33 | +const mymodel = load("mymodel.pt"); |
| 34 | +``` |
| 35 | + |
| 36 | +### Declare a custom transform |
| 37 | +Custom transform classes must conform to the following requirements. A sample may be found below: |
| 38 | +- Be callable. [This](https://hackernoon.com/creating-callable-objects-in-javascript-d21l3te1) is a good starting point, I prefer to just extend the Javascript `Function` interface, like how I've done in [`transforms.js`](https://github.com/raghavmecheri/pytorchjs/blob/master/src/torchvision/transforms/transforms.js) |
| 39 | +- Return a `numjs` object |
| 40 | +```js |
| 41 | +export class RandomAdd extends Function { |
| 42 | + /** |
| 43 | + * Create a new callable RandomAdd object |
| 44 | + */ |
| 45 | + constructor() { |
| 46 | + super(); |
| 47 | + // NOTE: This is just some template code I found at the linked Hackernoon page to make the function callable :) |
| 48 | + return new Proxy(this, { |
| 49 | + apply: (_target, _thisArg, argumentsList) => { |
| 50 | + const x = argumentsList[0]; |
| 51 | + const output = this.__call__(x); |
| 52 | + return output; |
| 53 | + }, |
| 54 | + }); |
| 55 | + } |
| 56 | + |
| 57 | + // All I'm doing here is adding a random integer to every value, because why not |
| 58 | + __call__ = (x) => x.add(Math.floor(Math.random() * 10)); |
| 59 | +} |
| 60 | +``` |
| 61 | + |
| 62 | +### Implement this custom transform |
| 63 | +```js |
| 64 | +const transforms = new Compose([ |
| 65 | + new Resize({height: 224, width: 224}), |
| 66 | + new RandomAdd(), |
| 67 | + new InvertAxes() |
| 68 | +]); |
| 69 | +``` |
| 70 | + |
| 71 | +### Load data and get inference results |
| 72 | +```js |
| 73 | +const loader = new DataLoader(new ImageFolder("dataset"), 1, transforms); |
| 74 | +const results = await mymodel(loader); |
| 75 | +``` |
0 commit comments