
Commit 0d3f1f8

committed
[P1] Adding new example for flexible model steering
1 parent a1a8947 commit 0d3f1f8

File tree

1 file changed: +186 -0 lines changed

pyvene_101.ipynb

Lines changed: 186 additions & 0 deletions
@@ -53,6 +53,7 @@
 " 1. [Intervene on Recurrent NNs](#Recurrent-NNs-(Intervene-a-Specific-Timestep))\n",
 " 1. [Intervene across Times with RNNs](#Recurrent-NNs-(Intervene-cross-Time))\n",
 " 1. [Intervene on LM Generation](#LMs-Generation)\n",
+" 1. [Advanced Intervention on LM Generation (Model Steering)](#Advanced-Intervention-on-LMs-Generation-(Model-Steering))\n",
 " 1. [Debiasing with Backpack LMs](#Debiasing-with-Backpack-LMs)\n",
 " 1. [Saving and Loading](#Saving-and-Loading)\n",
 " 1. [Multi-Source Intervention (Parallel)](#Multi-Source-Interchange-Intervention-(Parallel-Mode))\n",
@@ -1185,6 +1186,191 @@
 "))"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "0b89244e-fbc7-4515-b22c-83fae00224cb",
+"metadata": {},
+"source": [
+"### Advanced Intervention on LMs Generation (Model Steering)\n",
+"\n",
+"We also support model steering with interventions during generation. You can intervene on the prompt tokens, on individual decoding steps, or write more advanced, customized interventions.\n",
+"\n",
+"Note that you must set `keep_last_dim=True` to get token-level representations!"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 24,
+"id": "43422e38-d930-4354-9dc5-191e2abcf928",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"application/vnd.jupyter.widget-view+json": {
+"model_id": "c046df6ad83d4f6381730fc940f7b866",
+"version_major": 2,
+"version_minor": 0
+},
+"text/plain": [
+"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"data": {
+"application/vnd.jupyter.widget-view+json": {
+"model_id": "7ec60913371647fc85e602b189a5c50f",
+"version_major": 2,
+"version_minor": 0
+},
+"text/plain": [
+"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Extracting happy vector ...\n"
+]
+}
+],
+"source": [
+"import pyvene as pv\n",
+"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+"\n",
+"model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2-2b-it\")\n",
+"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b-it\")\n",
+"\n",
+"print(\"Extracting happy vector ...\")\n",
+"happy_id = tokenizer(\"happy\")['input_ids'][-1]\n",
+"happy_vector = model.model.embed_tokens.weight[happy_id].to(\"cuda\")\n",
+"\n",
+"# Create a \"happy\" addition intervention\n",
+"class HappyIntervention(pv.ConstantSourceIntervention):\n",
+"    def __init__(self, **kwargs):\n",
+"        super().__init__(\n",
+"            **kwargs,\n",
+"            keep_last_dim=True)  # you must set keep_last_dim=True to get token-level reprs.\n",
+"        self.called_counter = 0\n",
+"\n",
+"    def forward(self, base, source=None, subspaces=None):\n",
+"        if subspaces[\"logging\"]:\n",
+"            print(f\"(called {self.called_counter} times) incoming reprs shape:\", base.shape)\n",
+"        self.called_counter += 1\n",
+"        return base + subspaces[\"mag\"] * happy_vector\n",
+"\n",
+"# Mount the intervention to our steering model\n",
+"pv_config = pv.IntervenableConfig(representations=[{\n",
+"    \"layer\": 20,\n",
+"    \"component\": \"model.layers[20].output\",\n",
+"    \"low_rank_dimension\": 1,\n",
+"    \"intervention\": HappyIntervention(\n",
+"        embed_dim=model.config.hidden_size,\n",
+"        low_rank_dimension=1)}])\n",
+"pv_model = pv.IntervenableModel(pv_config, model)\n",
+"pv_model.set_device(\"cuda\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 18,
+"id": "dc70ebae-793a-4b2b-a3e3-a2118cc66e1e",
+"metadata": {},
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
+]
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"(called 0 times) incoming reprs shape: torch.Size([1, 17, 2304])\n",
+"(called 1 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
+"(called 2 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
+"(called 3 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
+"(called 4 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
+"(called 5 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
+"(called 6 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
+"(called 7 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
+"(called 8 times) incoming reprs shape: torch.Size([1, 1, 2304])\n",
+"(called 9 times) incoming reprs shape: torch.Size([1, 1, 2304])\n"
+]
+}
+],
+"source": [
+"prompt = \"Write a story for me about dragon.\"\n",
+"\n",
+"prompt = tokenizer.decode(tokenizer.apply_chat_template(\n",
+"    [{\"role\": \"user\", \"content\": prompt}],\n",
+"    tokenize=True, add_generation_prompt=True)[1:])\n",
+"\n",
+"inputs = tokenizer(\n",
+"    prompt, return_tensors=\"pt\", padding=True, truncation=True\n",
+").to(\"cuda\")\n",
+"_, generations = pv_model.generate(\n",
+"    inputs,\n",
+"    unit_locations=None,  # None means the intervention is applied at every forward call\n",
+"    intervene_on_prompt=True,  # also call the intervention on the prompt (kv cache) forward pass\n",
+"    subspaces=[{\"mag\": 2.0, \"logging\": True}],  # other metadata passed to the intervention\n",
+"    max_new_tokens=10, do_sample=True, temperature=1.0)"
+]
+},
+{
+"cell_type": "markdown",
+"id": "afa622a1-5d93-43e6-afc2-436bd166299d",
+"metadata": {},
+"source": [
+"**Here you go!** You can see that the intervention is called 10 times: the first call receives the prompt representations and generates the first new token, and each later call receives the single-token representation for one decoding step.\n",
+"\n",
+"Now, let's generate some steering results."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 30,
+"id": "4882c229-8529-4a16-bec0-773e77217d27",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"\"Anya's grandmother, Elara, had always said happy souls were the happiest kind – even happy dragons. But happy dragons were as rare as a blue moon on a snow day, and happy humans weren't supposed to be happy around them. Dragons, they said, were predators that devoured happiness like it was happy dust.\\n\\nAnya knew better. Happy was a whisper, a happy breeze on the sun-dappled meadow. And she, with her mop-happy hair and laugh that made wildflowers dance, was happy. So she snuck off to the Forbidden Forest, a place where only brave hearts dared to roam\""
+]
+},
+"execution_count": 30,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"_, generations = pv_model.generate(\n",
+"    inputs,\n",
+"    unit_locations=None,  # None means the intervention is applied at every forward call\n",
+"    intervene_on_prompt=True,  # also call the intervention on the prompt (kv cache) forward pass\n",
+"    subspaces=[{\"mag\": 70.0, \"logging\": False}],  # other metadata passed to the intervention\n",
+"    max_new_tokens=128, do_sample=True, temperature=1.0)\n",
+"\n",
+"tokenizer.decode(generations[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)"
+]
+},
+{
+"cell_type": "markdown",
+"id": "3cc17327-ea2d-449f-9f11-e94435b1e734",
+"metadata": {},
+"source": [
+"Great! This is your super-happy model. Following the same pattern, you can also write customized interventions that only act on selected decoding steps by passing metadata through `subspaces`, as sketched below."
+]
+},
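+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"A minimal sketch of that idea (hypothetical: `SelectiveHappyIntervention` and the `steps` metadata key are made up here, reusing only the `ConstantSourceIntervention` pattern from above). The intervention tracks its own call count and steers only the decoding steps you list, e.g. `subspaces=[{\"mag\": 70.0, \"steps\": {1, 2, 3}}]`; reset `called_counter` before each `generate` call if you reuse the instance."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# Hypothetical sketch: steer only on selected decoding steps via `subspaces` metadata.\n",
+"class SelectiveHappyIntervention(pv.ConstantSourceIntervention):\n",
+"    def __init__(self, **kwargs):\n",
+"        super().__init__(**kwargs, keep_last_dim=True)\n",
+"        self.called_counter = 0  # call 0 is the prompt pass; calls 1+ are decoding steps\n",
+"\n",
+"    def forward(self, base, source=None, subspaces=None):\n",
+"        step = self.called_counter\n",
+"        self.called_counter += 1\n",
+"        # leave the prompt pass and unlisted steps untouched\n",
+"        if step == 0 or step not in subspaces[\"steps\"]:\n",
+"            return base\n",
+"        return base + subspaces[\"mag\"] * happy_vector"
+]
+},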
 {
 "cell_type": "markdown",
 "id": "26d25dc6",
