Add feature extraction pipeline to mediaproc

This commit is contained in:
Liam 2025-01-08 21:35:50 -05:00
parent beb7cd49ab
commit 2bd7ddf9d1
13 changed files with 957 additions and 41 deletions

View file

@ -62,14 +62,16 @@ RUN cd /tmp \
COPY native/philomena /tmp/philomena
COPY docker/mediaproc/safe-rsvg-convert /usr/bin/safe-rsvg-convert
ADD https://github.com/liamwhite/philomena-ris-inference-toolkit/releases/download/v1.0/dinov2-with-registers-base.pt /usr/share/dinov2-with-registers-base.pt
RUN cd /tmp/philomena \
&& cargo build --release -p mediaproc_server \
&& cp target/release/mediaproc_server /usr/bin/mediaproc_server
&& cp target/release/mediaproc_server /usr/bin/mediaproc_server \
&& find target/release/build -regextype posix-extended -regex '^.*\.so(\.[0-9]+)*$' -exec cp '{}' /usr/lib/ ';'
# Set up unprivileged user account
RUN useradd -ms /bin/bash mediaproc
USER mediaproc
WORKDIR /home/mediaproc
ENV RUST_LOG=trace
CMD ["/usr/bin/mediaproc_server", "0.0.0.0:1500"]
CMD ["/usr/bin/mediaproc_server", "0.0.0.0:1500", "/usr/share/dinov2-with-registers-base.pt"]

View file

@ -12,6 +12,10 @@ defmodule Philomena.Native do
@spec camo_image_url(String.t()) :: String.t()
def camo_image_url(_uri), do: :erlang.nif_error(:nif_not_loaded)
# NIF stub for feature extraction. This Elixir body only raises
# `:nif_not_loaded`; the real implementation is presumably supplied by the
# native library when it is loaded (verify against the Rust NIF source).
#
# Takes the media processor address and an image path and returns `:ok`
# immediately. Callers in this codebase then wait for a
# `{:get_features_reply, reply}` message carrying the actual result.
@spec async_get_features(String.t(), String.t()) :: :ok
def async_get_features(_server_addr, _path),
do: :erlang.nif_error(:nif_not_loaded)
@spec async_process_command(String.t(), String.t(), [String.t()]) :: :ok
def async_process_command(_server_addr, _program, _arguments),
do: :erlang.nif_error(:nif_not_loaded)

View file

@ -7,11 +7,23 @@ defmodule PhilomenaMedia.Remote do
:ok = Philomena.Native.async_process_command(mediaproc_addr(), command, args)
receive do
{:command_reply, command_reply} ->
{:process_command_reply, command_reply} ->
{command_reply.stdout, command_reply.status}
end
end
@doc """
Requests a feature vector for the image at `path`, for use in reverse image search.

The request is handed to the native layer, which returns `:ok` straight away.
This function then blocks the calling process until a
`{:get_features_reply, reply}` message arrives and returns `reply` unchanged.
No timeout is applied to the wait.
"""
def get_features(path) do
  server_addr = mediaproc_addr()
  :ok = Philomena.Native.async_get_features(server_addr, path)

  receive do
    {:get_features_reply, reply} -> reply
  end
end
# Looks up the media processor address under the `:mediaproc_addr` key of the
# `:philomena` application environment. `Application.get_env/2` returns `nil`
# when the key is not configured, so callers receive `nil` in that case.
defp mediaproc_addr do
Application.get_env(:philomena, :mediaproc_addr)
end

View file

@ -17,6 +17,17 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aho-corasick"
version = "1.1.3"
@ -128,6 +139,18 @@ version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bincode"
version = "1.3.3"
@ -137,12 +160,27 @@ dependencies = [
"serde",
]
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "bon"
version = "3.3.2"
@ -174,18 +212,51 @@ version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "bytemuck"
version = "1.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "byteorder-lite"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]]
name = "bytes"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b"
[[package]]
name = "bzip2"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8"
dependencies = [
"bzip2-sys",
"libc",
]
[[package]]
name = "bzip2-sys"
version = "0.1.11+1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
dependencies = [
"cc",
"libc",
"pkg-config",
]
[[package]]
name = "caseless"
version = "0.2.2"
@ -201,6 +272,8 @@ version = "1.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7"
dependencies = [
"jobserver",
"libc",
"shlex",
]
@ -210,6 +283,16 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
]
[[package]]
name = "clap"
version = "4.5.24"
@ -272,6 +355,21 @@ dependencies = [
"unicode_categories",
]
[[package]]
name = "constant_time_eq"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
[[package]]
name = "cpufeatures"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3"
dependencies = [
"libc",
]
[[package]]
name = "crc32fast"
version = "1.4.2"
@ -287,6 +385,22 @@ version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "darling"
version = "0.20.10"
@ -322,6 +436,15 @@ dependencies = [
"syn",
]
[[package]]
name = "deranged"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
dependencies = [
"powerfmt",
]
[[package]]
name = "derive_arbitrary"
version = "1.4.1"
@ -339,6 +462,17 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "339544cc9e2c4dc3fc7149fd630c5f22263a4fdf18a98afd0075784968b5cf00"
[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
@ -433,6 +567,15 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fdeflate"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
dependencies = [
"simd-adler32",
]
[[package]]
name = "flate2"
version = "1.0.35"
@ -547,6 +690,16 @@ dependencies = [
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getrandom"
version = "0.2.15"
@ -564,6 +717,16 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"cfg-if",
"crunchy",
]
[[package]]
name = "hashbrown"
version = "0.15.2"
@ -576,6 +739,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "http"
version = "0.2.12"
@ -738,6 +910,20 @@ dependencies = [
"icu_properties",
]
[[package]]
name = "image"
version = "0.25.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd6f44aed642f18953a158afeb30206f4d50da59fbc66ecb53c66488de73563b"
dependencies = [
"bytemuck",
"byteorder-lite",
"num-traits",
"png",
"zune-core",
"zune-jpeg",
]
[[package]]
name = "indexmap"
version = "2.7.0"
@ -748,6 +934,15 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "inout"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
dependencies = [
"generic-array",
]
[[package]]
name = "inventory"
version = "0.3.17"
@ -789,6 +984,15 @@ dependencies = [
"libc",
]
[[package]]
name = "jobserver"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
dependencies = [
"libc",
]
[[package]]
name = "js-sys"
version = "0.3.76"
@ -855,6 +1059,16 @@ version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "matrixmultiply"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "mediaproc"
version = "0.1.0"
@ -881,8 +1095,10 @@ dependencies = [
"clap",
"env_logger",
"futures",
"image",
"mediaproc",
"tarpc",
"tch",
"tempfile",
"tokio",
"tracing",
@ -901,6 +1117,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
@ -914,6 +1131,52 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"rawpointer",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "object"
version = "0.36.7"
@ -989,6 +1252,29 @@ dependencies = [
"windows-targets",
]
[[package]]
name = "password-hash"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
dependencies = [
"base64ct",
"rand_core",
"subtle",
]
[[package]]
name = "pbkdf2"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
dependencies = [
"digest",
"hmac",
"password-hash",
"sha2",
]
[[package]]
name = "percent-encoding"
version = "2.3.1"
@ -999,18 +1285,18 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
name = "philomena"
version = "0.3.0"
dependencies = [
"base64",
"base64 0.21.7",
"comrak",
"http",
"jemallocator",
"mediaproc",
"once_cell",
"regex",
"ring",
"ring 0.16.20",
"rustler",
"tokio",
"url",
"zip",
"zip 2.2.2",
]
[[package]]
@ -1045,6 +1331,31 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pkg-config"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
[[package]]
name = "png"
version = "0.17.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526"
dependencies = [
"bitflags 1.3.2",
"crc32fast",
"fdeflate",
"flate2",
"miniz_oxide",
]
[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.20"
@ -1112,13 +1423,19 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "redox_syscall"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834"
dependencies = [
"bitflags",
"bitflags 2.6.0",
]
[[package]]
@ -1165,12 +1482,27 @@ dependencies = [
"cc",
"libc",
"once_cell",
"spin",
"untrusted",
"spin 0.5.2",
"untrusted 0.7.1",
"web-sys",
"winapi",
]
[[package]]
name = "ring"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
"cc",
"cfg-if",
"getrandom",
"libc",
"spin 0.9.8",
"untrusted 0.9.0",
"windows-sys 0.52.0",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
@ -1183,7 +1515,7 @@ version = "0.38.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6"
dependencies = [
"bitflags",
"bitflags 2.6.0",
"errno",
"libc",
"linux-raw-sys",
@ -1215,6 +1547,38 @@ dependencies = [
"syn",
]
[[package]]
name = "rustls"
version = "0.23.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b"
dependencies = [
"log",
"once_cell",
"ring 0.17.8",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-pki-types"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37"
[[package]]
name = "rustls-webpki"
version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring 0.17.8",
"rustls-pki-types",
"untrusted 0.9.0",
]
[[package]]
name = "rustversion"
version = "1.0.19"
@ -1227,6 +1591,16 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "safetensors"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -1265,6 +1639,28 @@ dependencies = [
"serde",
]
[[package]]
name = "sha1"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sha2"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
@ -1336,6 +1732,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
@ -1354,6 +1756,12 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "subtle"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
version = "2.0.95"
@ -1412,6 +1820,23 @@ dependencies = [
"syn",
]
[[package]]
name = "tch"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb3500c87ef72447c23b33ed6f15fac45a616b09bcac53e62e0e4386bddb3b9d"
dependencies = [
"half",
"lazy_static",
"libc",
"ndarray",
"rand",
"safetensors",
"thiserror 1.0.69",
"torch-sys",
"zip 0.6.6",
]
[[package]]
name = "tempfile"
version = "3.15.0"
@ -1476,6 +1901,25 @@ dependencies = [
"once_cell",
]
[[package]]
name = "time"
version = "0.3.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21"
dependencies = [
"deranged",
"num-conv",
"powerfmt",
"serde",
"time-core",
]
[[package]]
name = "time-core"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
[[package]]
name = "tinystr"
version = "0.7.6"
@ -1560,6 +2004,21 @@ dependencies = [
"tokio",
]
[[package]]
name = "torch-sys"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61b87ed41261d4278060c3ba3e735c224687cf312403e4565f2ca75310279d73"
dependencies = [
"anyhow",
"cc",
"libc",
"serde",
"serde_json",
"ureq",
"zip 0.6.6",
]
[[package]]
name = "tracing"
version = "0.1.41"
@ -1626,6 +2085,12 @@ version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a"
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "unicode-ident"
version = "1.0.14"
@ -1653,6 +2118,30 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
dependencies = [
"base64 0.22.1",
"flate2",
"log",
"once_cell",
"rustls",
"rustls-pki-types",
"serde",
"serde_json",
"url",
"webpki-roots",
]
[[package]]
name = "url"
version = "2.5.4"
@ -1688,6 +2177,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "version_check"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
@ -1768,6 +2263,15 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.26.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "winapi"
version = "0.3.9"
@ -1950,6 +2454,12 @@ dependencies = [
"synstructure",
]
[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
[[package]]
name = "zerovec"
version = "0.10.4"
@ -1972,6 +2482,26 @@ dependencies = [
"syn",
]
[[package]]
name = "zip"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
dependencies = [
"aes",
"byteorder",
"bzip2",
"constant_time_eq",
"crc32fast",
"crossbeam-utils",
"flate2",
"hmac",
"pbkdf2",
"sha1",
"time",
"zstd",
]
[[package]]
name = "zip"
version = "2.2.2"
@ -2002,3 +2532,47 @@ dependencies = [
"once_cell",
"simd-adler32",
]
[[package]]
name = "zstd"
version = "0.11.2+zstd.1.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
dependencies = [
"zstd-safe",
]
[[package]]
name = "zstd-safe"
version = "5.0.2+zstd.1.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db"
dependencies = [
"libc",
"zstd-sys",
]
[[package]]
name = "zstd-sys"
version = "2.0.13+zstd.1.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa"
dependencies = [
"cc",
"pkg-config",
]
[[package]]
name = "zune-core"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a"
[[package]]
name = "zune-jpeg"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99a5bab8d7dedf81405c4bb1f2b83ea057643d9cb28778cea9eecddeedd2e028"
dependencies = [
"zune-core",
]

View file

@ -98,25 +98,29 @@ fn update_replacements(
Ok(())
}
fn context_with_1_hour_deadline() -> Context {
pub fn context_with_deadline(secs_from_now: u64) -> Context {
let mut context = Context::current();
context.deadline = Instant::now() + Duration::from_secs(60 * 60);
context.deadline = Instant::now() + Duration::from_secs(secs_from_now);
context
}
/// Returns a request context whose deadline is one hour (60 * 60 seconds) from now.
///
/// Thin convenience wrapper around `context_with_deadline`.
pub fn context_with_1_hour_deadline() -> Context {
context_with_deadline(60 * 60)
}
/// Returns a request context whose deadline is ten seconds from now.
///
/// Thin convenience wrapper around `context_with_deadline`.
pub fn context_with_10_second_deadline() -> Context {
context_with_deadline(10)
}
pub async fn execute_command(
client: &MediaProcessorClient,
program: String,
arguments: Vec<String>,
ctx: Context,
) -> Result<CommandReply, ExecuteCommandError> {
let call_params = create_replacements(arguments.into_iter());
let (reply, file_map) = client
.execute_command(
context_with_1_hour_deadline(),
program,
call_params.arguments,
call_params.file_map,
)
.execute_command(ctx, program, call_params.arguments, call_params.file_map)
.await
.map_err(|_| ExecuteCommandError::UnknownError)??;

View file

@ -12,11 +12,16 @@ pub trait MediaProcessor {
arguments: Vec<String>,
file_map: FileMap,
) -> Result<(CommandReply, FileMap), ExecuteCommandError>;
/// Runs feature extraction on an image file bytes (PNG or JPEG).
async fn get_features(image: Vec<u8>) -> Result<Vec<f32>, FeatureExtractionError>;
}
/// Errors which can occur during command execution.
#[derive(Debug, Deserialize, Serialize)]
pub enum ExecuteCommandError {
/// Failed to connect to server.
ConnectionError,
/// Requested program was not allowed to be executed.
UnpermittedProgram(String),
/// Failed to launch program.
@ -31,6 +36,19 @@ pub enum ExecuteCommandError {
UnknownError,
}
/// Errors which can occur during image feature extraction.
///
/// This is the error half of the `get_features` RPC result, which is why it
/// derives `Serialize` and `Deserialize` in addition to `Debug`.
#[derive(Debug, Deserialize, Serialize)]
pub enum FeatureExtractionError {
/// Failed to connect to server.
ConnectionError,
/// Generic filesystem error.
LocalFilesystemError,
/// Unrecognized image format.
UnknownImageFormat,
/// Failed to decode the image.
ImageDecodeError,
}
/// Enumeration of permitted program names.
pub static PERMITTED_PROGRAMS: Lazy<HashSet<&'static str>> = Lazy::new(|| {
vec![

View file

@ -2,7 +2,7 @@ use std::io::Write;
use std::process::ExitCode;
use clap::{Parser, Subcommand};
use mediaproc::client::{connect_to_socket_server, execute_command};
use mediaproc::client;
use mediaproc::MediaProcessorClient;
#[derive(Parser, Debug)]
@ -28,12 +28,17 @@ enum InvocationType {
/// Arguments to pass to program.
args: Vec<String>,
},
/// Get DINOv2 features from the given image file (PNG or JPEG).
ExtractFeatures {
/// Filename to extract from.
file_name: String,
},
}
#[tokio::main(flavor = "current_thread")]
async fn main() -> ExitCode {
let args = Arguments::parse();
let client = connect_to_socket_server(&args.server_addr)
let client = client::connect_to_socket_server(&args.server_addr)
.await
.expect("failed to connect to server");
@ -41,6 +46,9 @@ async fn main() -> ExitCode {
InvocationType::ExecuteCommand { program, args } => {
run_command_client(&client, program, args).await
}
InvocationType::ExtractFeatures { file_name } => {
run_feature_extraction_client(&client, file_name).await
}
}
}
@ -49,7 +57,10 @@ async fn run_command_client(
program: String,
args: Vec<String>,
) -> ExitCode {
let reply = execute_command(client, program, args).await.unwrap();
let ctx = client::context_with_1_hour_deadline();
let reply = client::execute_command(client, program, args, ctx)
.await
.unwrap();
write_then_drop(std::io::stderr(), reply.stderr);
write_then_drop(std::io::stdout(), reply.stdout);
@ -60,3 +71,29 @@ async fn run_command_client(
/// Writes all of `data` to `stream`, then drops the stream.
///
/// `stream` is taken by value, so it goes out of scope when this function
/// returns. Panics if the write fails.
fn write_then_drop(mut stream: impl Write, data: Vec<u8>) {
stream.write_all(&data).unwrap()
}
/// Reads `file_name` from disk, asks the server for its feature vector and
/// prints the components to stdout on one line, separated by single spaces.
///
/// Panics if the file cannot be read, if the RPC itself fails, or if the
/// server reports a feature extraction error. Always returns
/// `ExitCode::SUCCESS` otherwise.
async fn run_feature_extraction_client(
    client: &MediaProcessorClient,
    file_name: String,
) -> ExitCode {
    let image = std::fs::read(file_name).unwrap();
    let ctx = client::context_with_10_second_deadline();

    // Two layers of Result: the call itself, then the server's own result.
    let features = client.get_features(ctx, image).await.unwrap().unwrap();

    // Render every component first, then emit the whole line in one go.
    let rendered: Vec<String> = features
        .iter()
        .map(|component| component.to_string())
        .collect();
    println!("{}", rendered.join(" "));

    ExitCode::SUCCESS
}

View file

@ -7,8 +7,10 @@ edition = "2021"
env_logger = "0.11"
clap = { version = "4.5", features = ["derive"] }
futures = "0.3"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
mediaproc = { path = "../mediaproc" }
tarpc = { version = "0.35", features = ["full"] }
tch = { version = "0.18.1", features = ["download-libtorch"] }
tempfile = "3"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"

View file

@ -0,0 +1,106 @@
use std::io::Cursor;
use tch::{CModule, Device, IValue, Tensor};
use super::io;
use crate::FeatureExtractionError;
/// Each DINOv2 patch is 14x14
pub const PATCH_DIM: i64 = 14;
/// Everything produced by one feature-extraction pass (see `get_model_result`).
pub struct ModelResult {
/// Patch counts the image was resized to before inference.
/// NOTE(review): axis order of the pair is not visible here — confirm
/// against `resize_image_by_patch_count`.
pub patches: (i64, i64),
/// The resized image tensor that was fed to the model.
pub image: Tensor,
/// The model's pooler output, squeezed and scaled by the inverse of its norm.
pub features: Tensor,
/// The first element of the model's output tuple.
pub last_hidden_state: Tensor,
}
/// Runs `model` on `image` and returns `(last_hidden_state, pooler_output)`.
///
/// Panics if the forward pass fails or if the model does not return a
/// 2-tuple of tensors. This is deliberate: the output shape is a property of
/// the model rather than of the input image, and a model in the wrong format
/// cannot be recovered from.
fn infer(image: &Tensor, model: &CModule) -> (Tensor, Tensor) {
    let inputs = [IValue::Tensor(image.shallow_clone())];

    let elements = match model.forward_is(&inputs).unwrap() {
        IValue::Tuple(elements) if elements.len() == 2 => elements,
        _ => unreachable!("expected (last_hidden_state, pooler_output)"),
    };

    // Take ownership of both tuple members in order.
    let mut members = elements.into_iter();
    let first = members.next();
    let second = members.next();

    match (first, second) {
        (Some(IValue::Tensor(last_hidden_state)), Some(IValue::Tensor(pooler_output))) => {
            (last_hidden_state, pooler_output)
        }
        _ => unreachable!("expected 2-tuple of tensors"),
    }
}
/// Scales `pooler_output` by the reciprocal of its norm.
fn scaled_result(pooler_output: &Tensor) -> Tensor {
    // norm^-1, computed once and applied to every component.
    let inverse_norm = pooler_output.norm().pow_tensor_scalar(-1);

    pooler_output.multiply(&inverse_norm)
}
/// Decodes `image`, resizes it to a fixed patch grid, runs `model` on it and
/// returns the collected outputs.
///
/// Returns an error if the image cannot be loaded. Panics (inside `infer`) if
/// the model output does not have the expected form.
pub fn get_model_result<R>(
    image: R,
    model: &CModule,
    device: Device,
) -> Result<ModelResult, FeatureExtractionError>
where
    R: std::io::Read + std::io::Seek,
{
    // Decode the input into a tensor on the target device.
    let decoded = io::load_image(image, device)?;

    // Features are unstable across different global scales, and somewhat
    // stable across dimensional scales.
    //
    // 18 patches per side (252x252 pixels) is used instead of 16 (224x224):
    // it gives a more detailed result and attention map at almost exactly
    // the same computational cost.
    //
    // Highly non-square patch grids can still yield meaningful features, but
    // in practice they make no difference when identifying rescaled copies
    // that keep the aspect ratio, and the resulting vectors are not similar
    // enough to identify crops.
    const PATCHES_PER_SIDE: i64 = 18;
    let image_scale = 1;
    let side = PATCHES_PER_SIDE * image_scale;
    let patches = (side, side);

    // Resize the image to exactly cover the chosen patch grid.
    let image = io::resize_image_by_patch_count(decoded, patches, PATCH_DIM);

    // The pooler output is the [CLS] token generated by the model; it holds
    // high-quality, robust features of the input image.
    let (last_hidden_state, pooler_output) = infer(&image, model);
    let features = scaled_result(&pooler_output.squeeze());

    Ok(ModelResult {
        patches,
        image,
        features,
        last_hidden_state,
    })
}
/// A loaded model together with the device it runs on.
pub struct Executor {
    device: Device,
    model: CModule,
}

impl Executor {
    /// Loads the model at `model_path`; returns `None` if loading fails.
    pub fn new(model_path: &str) -> Option<Self> {
        io::device_and_model(model_path).map(|(device, model)| Self { device, model })
    }

    /// Extracts the feature vector for an encoded image held in memory.
    ///
    /// Returns an error if the image cannot be loaded. Panics if the feature
    /// tensor cannot be iterated.
    pub fn extract(&self, image: &[u8]) -> Result<Vec<f32>, FeatureExtractionError> {
        let reader = Cursor::new(image);
        let result = get_model_result(reader, &self.model, self.device)?;

        // Components are read as f64 and narrowed to f32 for the reply.
        let components = result.features.iter::<f64>().unwrap();

        Ok(components.map(|value| value as f32).collect())
    }
}

View file

@ -0,0 +1,100 @@
use image::{DynamicImage, ImageBuffer, ImageReader, Pixel};
use std::io::BufReader;
use tch::{CModule, Device, Tensor};
use crate::FeatureExtractionError;
/// Picks a device (CUDA when available) and loads the model at `model_path`
/// onto it.
///
/// Returns `None` if the model cannot be loaded; the underlying load error is
/// discarded.
pub fn device_and_model(model_path: &str) -> Option<(Device, CModule)> {
    let device = Device::cuda_if_available();

    CModule::load_on_device(model_path, device)
        .ok()
        .map(|model| (device, model))
}
/// Converts an `f32` image buffer into a channel-first tensor on `device`.
///
/// The samples are loaded as `[h, w, c]` (the buffer's own row-major,
/// channel-interleaved order) and then permuted to `[c, h, w]`.
fn into_tensor<P: Pixel<Subpixel = f32>>(
    image: ImageBuffer<P, Vec<f32>>,
    device: Device,
) -> Tensor {
    let w: i64 = image.width().into();
    let h: i64 = image.height().into();
    let c: i64 = P::CHANNEL_COUNT.into();

    // The buffer dereferences to its samples as one contiguous slice, in the
    // same order `pixels()` would yield them, so there is no need to collect
    // them into a second `Vec` first.
    let pixels = {
        let samples: &[f32] = &image;
        Tensor::from_slice(samples)
    };

    // `from_slice` copied the samples, so release the source buffer now
    // rather than holding it across the device transfer below.
    drop(image);

    pixels.to(device).reshape([h, w, c]).permute([2, 0, 1])
}
/// Converts a decoded image into a 3-channel float tensor, folding any alpha
/// channel into the color data.
///
/// Formats without alpha are converted to RGB directly. Everything else is
/// converted to RGBA, optionally premultiplied, and blended so that
/// transparency ends up slightly darker than opaque black.
fn strip_transparency(image: DynamicImage, device: Device) -> Tensor {
    // Formats that carry no alpha channel need no compositing at all.
    let has_no_alpha = matches!(
        &image,
        DynamicImage::ImageRgb8(..)
            | DynamicImage::ImageLuma8(..)
            | DynamicImage::ImageLuma16(..)
            | DynamicImage::ImageRgb16(..)
            | DynamicImage::ImageRgb32F(..)
    );
    if has_no_alpha {
        return into_tensor(image.into_rgb32f(), device);
    }

    let width: i64 = image.width().into();
    let height: i64 = image.height().into();
    let shape = [3, height, width];

    // Split the RGBA tensor into its color planes and an alpha plane
    // broadcast to the same shape as the color planes.
    let rgba = into_tensor(image.into_rgba32f(), device);
    let alpha = rgba.slice(0, 3, 4, 1).broadcast_to(shape);
    let color = rgba.slice(0, 0, 3, 1);

    // Decide whether premultiplication still has to be applied: if any color
    // value exceeds its alpha value, the data cannot already be premultiplied.
    //
    // PNG is the only input format where this matters, and it officially does
    // not carry premultiplied alpha, but many tools store it that way anyway.
    let ones = Tensor::ones(shape, (tch::Kind::Float, device));
    let not_premultiplied = color.gt_tensor(&alpha).any();
    let mask = alpha.where_self(&not_premultiplied, &ones);
    let color = color.multiply(&mask);

    // Pure transparency is rescaled to be 8 steps blacker than black.
    const ALPHA_LEVEL: f64 = 8.0 / 255.0;
    const COLOR_LEVEL: f64 = 1.0 - ALPHA_LEVEL;

    let scaled_color = color.multiply_scalar(COLOR_LEVEL);
    let scaled_alpha = alpha.multiply_scalar(ALPHA_LEVEL);

    // Unwrap is guaranteed safe because the dimensions and data type match.
    scaled_color.f_add(&scaled_alpha).unwrap()
}
/// Decodes an encoded image from `image` into a tensor on `device`.
///
/// The format is guessed from the content. A failure while guessing maps to
/// `FeatureExtractionError::UnknownImageFormat`, a failure while decoding to
/// `FeatureExtractionError::ImageDecodeError`. The decoded image is then
/// passed through `strip_transparency`.
pub fn load_image<R>(image: R, device: Device) -> Result<Tensor, FeatureExtractionError>
where
    R: std::io::Read + std::io::Seek,
{
    let reader = ImageReader::new(BufReader::new(image))
        .with_guessed_format()
        .map_err(|_| FeatureExtractionError::UnknownImageFormat)?;

    let decoded = reader
        .decode()
        .map_err(|_| FeatureExtractionError::ImageDecodeError)?;

    Ok(strip_transparency(decoded, device))
}
/// Resamples `image` to `size` (given as `(height, width)`) with bicubic
/// interpolation.
fn resize_tensor(image: Tensor, size: (i64, i64)) -> Tensor {
    let (height, width) = size;
    let output_size = [height, width];

    image.upsample_bicubic2d(output_size, true, None, None)
}
/// Resizes `image` so that it spans exactly `patches` patches of `patch_dim`
/// pixels each, where `patches` is `(rows, columns)`.
///
/// A leading dimension is added to the tensor before resizing, so the result
/// has one more dimension than the input.
pub fn resize_image_by_patch_count(image: Tensor, patches: (i64, i64), patch_dim: i64) -> Tensor {
    let (rows, columns) = patches;
    let target_size = (rows * patch_dim, columns * patch_dim);

    resize_tensor(image.unsqueeze(0), target_size)
}

View file

@ -1,12 +1,18 @@
use std::net::SocketAddr;
use std::sync::Arc;
use clap::Parser;
use dinov2::Executor;
use futures::{future, Future, StreamExt};
use mediaproc::{CommandReply, ExecuteCommandError, FileMap, MediaProcessor};
use mediaproc::{
CommandReply, ExecuteCommandError, FeatureExtractionError, FileMap, MediaProcessor,
};
use tarpc::context;
use tarpc::server::Channel;
mod command_server;
mod dinov2;
mod io;
mod signal;
#[derive(Parser, Debug)]
@ -14,10 +20,13 @@ mod signal;
// Command-line arguments of the server binary; `main` obtains them through
// `Arguments::parse()`.
//
// NOTE(review): plain `//` comments are used for maintainer notes here because,
// with clap's derive, the `///` comments on the fields are presumably what ends
// up in the generated help output, so rewording them changes user-visible text.
struct Arguments {
/// Socket address to bind to, like 127.0.0.1:1500
server_addr: SocketAddr,
/// DINOv2 with registers base model to load.
// Passed to `Executor::new` in `main`; startup aborts if the model cannot be loaded.
model_path: String,
}
#[derive(Clone)]
struct MediaProcessorServer;
struct MediaProcessorServer(Arc<Executor>);
impl MediaProcessor for MediaProcessorServer {
async fn execute_command(
@ -29,14 +38,24 @@ impl MediaProcessor for MediaProcessorServer {
) -> Result<(CommandReply, FileMap), ExecuteCommandError> {
command_server::execute_command(program, arguments, file_map).await
}
async fn get_features(
self,
_: context::Context,
image: Vec<u8>,
) -> Result<Vec<f32>, FeatureExtractionError> {
self.0.extract(&image)
}
}
fn main() {
env_logger::init();
let args = Arguments::parse();
let executor = Executor::new(&args.model_path).expect("failed to load Torch JIT model");
let executor = Arc::new(executor);
serve(&args);
serve(&args, executor);
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
@ -44,7 +63,7 @@ async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
}
#[tokio::main]
async fn serve(args: &Arguments) {
async fn serve(args: &Arguments, executor: Arc<Executor>) {
signal::install_handlers();
let codec = tarpc::tokio_serde::formats::Bincode::default;
@ -57,12 +76,10 @@ async fn serve(args: &Arguments) {
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))
.map(tarpc::server::BaseChannel::with_defaults)
.map(|channel| {
tokio::spawn(
channel
.execute(MediaProcessorServer.serve())
.for_each(spawn),
);
.map(move |channel| {
let server = MediaProcessorServer(executor.clone());
tokio::spawn(channel.execute(server.serve()).for_each(spawn));
})
.collect()
.await

View file

@ -39,6 +39,12 @@ fn camo_image_url(input: &str) -> String {
// Remote NIF wrappers.
/// NIF entry point that starts an asynchronous feature-extraction request.
///
/// The future from `remote::get_features` is handed to `asyncnif::call_async`,
/// which also receives `remote::get_features_reply_with_env` to turn the
/// eventual result into the reply term. The atom returned by `call_async` is
/// returned to the caller as is.
#[rustler::nif]
fn async_get_features(env: Env, server_addr: String, path: String) -> Atom {
    asyncnif::call_async(
        env,
        remote::get_features(server_addr, path),
        remote::get_features_reply_with_env,
    )
}
#[rustler::nif]
fn async_process_command(
env: Env,
@ -47,7 +53,7 @@ fn async_process_command(
arguments: Vec<String>,
) -> Atom {
let fut = remote::process_command(server_addr, program, arguments);
asyncnif::call_async(env, fut, remote::with_env)
asyncnif::call_async(env, fut, remote::command_reply_with_env)
}
// Zip NIF wrappers.

View file

@ -1,10 +1,13 @@
use mediaproc::client::{connect_to_socket_server, execute_command};
use mediaproc::CommandReply;
use mediaproc::client;
use mediaproc::{CommandReply, FeatureExtractionError};
use rustler::{atoms, Encoder, Env, NifStruct, OwnedBinary, Term};
// Atoms used when building the terms that are sent back to the caller:
// `get_features_reply` / `process_command_reply` tag the reply tuples, and
// `ok` / `error` tag the result inside a `get_features_reply`.
atoms! {
nil,
command_reply,
ok,
error,
get_features_reply,
process_command_reply,
}
#[derive(NifStruct)]
@ -30,7 +33,7 @@ pub async fn process_command(
program: String,
arguments: Vec<String>,
) -> CommandReply {
let client = match connect_to_socket_server(&server_addr).await {
let client = match client::connect_to_socket_server(&server_addr).await {
Some(client) => client,
None => {
return CommandReply {
@ -41,7 +44,8 @@ pub async fn process_command(
}
};
match execute_command(&client, program, arguments).await {
let ctx = client::context_with_1_hour_deadline();
match client::execute_command(&client, program, arguments, ctx).await {
Ok(reply) => reply,
Err(err) => CommandReply {
stdout: vec![],
@ -51,11 +55,29 @@ pub async fn process_command(
}
}
/// Converts the response into a {:command_reply, %CommandReply{...}} message
/// which gets sent back to the caller.
pub fn with_env<'a>(env: Env<'a>, r: CommandReply) -> Term<'a> {
/// Reads the image at `path` and asks the server at `server_addr` for its
/// feature vector.
///
/// Errors:
/// * `ConnectionError` - the server could not be reached, or the RPC itself failed.
/// * `LocalFilesystemError` - the file at `path` could not be read.
/// * any error the server returned for the extraction, passed through unchanged.
pub async fn get_features(
    server_addr: String,
    path: String,
) -> Result<Vec<f32>, FeatureExtractionError> {
    let client = client::connect_to_socket_server(&server_addr)
        .await
        .ok_or(FeatureExtractionError::ConnectionError)?;

    let image = match std::fs::read(path) {
        Ok(bytes) => bytes,
        Err(_) => return Err(FeatureExtractionError::LocalFilesystemError),
    };

    let ctx = client::context_with_10_second_deadline();

    // The outer result is the transport outcome, the inner one is the
    // server's own answer; only the transport failure is remapped.
    match client.get_features(ctx, image).await {
        Ok(reply) => reply,
        Err(_) => Err(FeatureExtractionError::ConnectionError),
    }
}
/// Converts the response into a {:process_command_reply, %CommandReply{...}}
/// message which gets sent back to the caller.
pub fn command_reply_with_env<'a>(env: Env<'a>, r: CommandReply) -> Term<'a> {
(
command_reply(),
process_command_reply(),
CommandReply_ {
stdout: binary_or_nil(env, r.stdout),
stderr: binary_or_nil(env, r.stderr),
@ -64,3 +86,15 @@ pub fn with_env<'a>(env: Env<'a>, r: CommandReply) -> Term<'a> {
)
.encode(env)
}
/// Converts the response into a {:get_features_reply, {:ok, [0.1, ..., 0.1]}}
/// message which gets sent back to the caller.
pub fn get_features_reply_with_env<'a>(
env: Env<'a>,
r: Result<Vec<f32>, FeatureExtractionError>,
) -> Term<'a> {
match r {
Ok(features) => (get_features_reply(), (ok(), features)).encode(env),
Err(e) => (get_features_reply(), (error(), format!("{e:?}"))).encode(env),
}
}