Skip to content

Commit 15bba11

Browse files
committed
Added sample custom tranform :)
1 parent 6c62cca commit 15bba11

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ const results = await squeezeNet(loader);
5353
```
5454

5555
### More Examples
56-
Additional examples of both setup and usage involving features like Torchvision Transforms and CUDA (in development) may be found [here](https://github.com/raghavmecheri/ptjs/tree/master/examples).
56+
Additional examples of both setup and usage involving features like Torchvision Transforms and CUDA (in development) may be found [here](https://github.com/raghavmecheri/pytorchjs/tree/master/examples/Usage).
5757

5858
## Key Features
5959
* Run your PyTorch models in a Javascript environment, without worrying about setting up Torchscript or downloading custom binaries

examples/Usage/Transforms.md

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Using a custom ptjs transform in your workflow
2+
3+
## Background
4+
This example assumes that you have a pre-trained torch model, that has been exported to [torchscript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html), as well as an image dataset conforming to the structure that the ptjs DatasetFolder class expects. The structure may also be found below, for reference, and a sample script to export a torch model to torchscript can be found [here](../Exporting/Exporting.md)
5+
6+
## Required File Structure
7+
```
8+
root/class_x/xxx.ext
9+
root/class_x/xxy.ext
10+
root/class_x/xxz.ext
11+
12+
root/class_y/123.ext
13+
root/class_y/nsdf3.ext
14+
root/class_y/asd932_.ext
15+
```
16+
17+
## Use Case
18+
To add a custom transform function to your `transforms.Compose` call
19+
20+
## Sample Script
21+
### Import Dependancies
22+
```js
23+
import { torch, torchvision } from 'pytorchjs';
24+
25+
const { load } = torch;
26+
const { DataLoader } = torch.utils.data;
27+
const { ImageFolder } = torchvision.datasets;
28+
const { Compose, Resize, InvertAxes, Normalize } = torchvision.transforms;
29+
```
30+
31+
### Load Model
32+
```js
33+
const mymodel = load("mymodel.pt");
34+
```
35+
36+
### Declare a custom transform
37+
Custom transform classes must conform to the following requirements. A sample may be found below:
38+
- Be callable. [This](https://hackernoon.com/creating-callable-objects-in-javascript-d21l3te1) is a good starting point, I prefer to just extend the Javascript `Function` interface, like how I've done in [`transforms.js`](https://github.com/raghavmecheri/pytorchjs/blob/master/src/torchvision/transforms/transforms.js)
39+
- Return a `numjs` object
40+
```js
41+
export class RandomAdd extends Function {
42+
/**
43+
* Create a new callable RandomAdd object
44+
*/
45+
constructor() {
46+
super();
47+
// NOTE: This is just some template code I found at the linked Hackernoon page to make the function callable :)
48+
return new Proxy(this, {
49+
apply: (_target, _thisArg, argumentsList) => {
50+
const x = argumentsList[0];
51+
const output = this.__call__(x);
52+
return output;
53+
},
54+
});
55+
}
56+
57+
// All I'm doing here is adding a random integer to every value, because why not
58+
__call__ = (x) => x.add(Math.floor(Math.random() * 10));
59+
}
60+
```
61+
62+
### Implement this custom transform
63+
```js
64+
const transforms = new Compose([
65+
new Resize({height: 224, width: 224}),
66+
new RandomAdd(),
67+
new InvertAxes()
68+
]);
69+
```
70+
71+
### Load data and get inference results
72+
```js
73+
const loader = new DataLoader(new ImageFolder("dataset"), 1, transforms);
74+
const results = await mymodel(loader);
75+
```

0 commit comments

Comments
 (0)