Skip to content

Commit 97a9f41

Browse files
author
Ashley Scillitoe
authored
Use PyTorch UAE in imdb example (#705)
1 parent 3daf62c commit 97a9f41

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

doc/source/examples/cd_text_imdb.ipynb

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -974,15 +974,14 @@
974974
"source": [
975975
"from alibi_detect.cd.pytorch import preprocess_drift\n",
976976
"from alibi_detect.models.pytorch import TransformerEmbedding\n",
977+
"from alibi_detect.cd.pytorch import UAE\n",
977978
"\n",
979+
"# Embedding model\n",
978980
"embedding_pt = TransformerEmbedding(model_name, emb_type, layers)\n",
979981
"\n",
980-
"model = nn.Sequential(\n",
981-
" embedding_pt,\n",
982-
" nn.Linear(768, 256),\n",
983-
" nn.ReLU(),\n",
984-
" nn.Linear(256, enc_dim)\n",
985-
").to(device).eval()\n",
982+
"# PyTorch untrained autoencoder\n",
983+
"uae = UAE(input_layer=embedding_pt, shape=shape, enc_dim=enc_dim)\n",
984+
"model = uae.to(device).eval()\n",
986985
"\n",
987986
"# define preprocessing function\n",
988987
"preprocess_fn = partial(preprocess_drift, model=model, tokenizer=tokenizer, \n",

0 commit comments

Comments
 (0)