From 2bd7ddf9d1440a0ab7b700518e423968c90b3c9d Mon Sep 17 00:00:00 2001 From: Liam <byteslice@airmail.cc> Date: Wed, 8 Jan 2025 21:35:50 -0500 Subject: [PATCH] Add feature extraction pipeline to mediaproc --- docker/mediaproc/Dockerfile | 6 +- lib/philomena/native.ex | 4 + lib/philomena_media/remote.ex | 14 +- native/philomena/Cargo.lock | 588 +++++++++++++++++- native/philomena/mediaproc/src/client.rs | 20 +- native/philomena/mediaproc/src/lib.rs | 18 + native/philomena/mediaproc_client/src/main.rs | 43 +- native/philomena/mediaproc_server/Cargo.toml | 2 + .../philomena/mediaproc_server/src/dinov2.rs | 106 ++++ native/philomena/mediaproc_server/src/io.rs | 100 +++ native/philomena/mediaproc_server/src/main.rs | 37 +- native/philomena/src/lib.rs | 8 +- native/philomena/src/remote.rs | 52 +- 13 files changed, 957 insertions(+), 41 deletions(-) create mode 100644 native/philomena/mediaproc_server/src/dinov2.rs create mode 100644 native/philomena/mediaproc_server/src/io.rs diff --git a/docker/mediaproc/Dockerfile b/docker/mediaproc/Dockerfile index 25b1748a..ebe094f3 100644 --- a/docker/mediaproc/Dockerfile +++ b/docker/mediaproc/Dockerfile @@ -62,14 +62,16 @@ RUN cd /tmp \ COPY native/philomena /tmp/philomena COPY docker/mediaproc/safe-rsvg-convert /usr/bin/safe-rsvg-convert +ADD https://github.com/liamwhite/philomena-ris-inference-toolkit/releases/download/v1.0/dinov2-with-registers-base.pt /usr/share/dinov2-with-registers-base.pt RUN cd /tmp/philomena \ && cargo build --release -p mediaproc_server \ - && cp target/release/mediaproc_server /usr/bin/mediaproc_server + && cp target/release/mediaproc_server /usr/bin/mediaproc_server \ + && find target/release/build -regextype posix-extended -regex '^.*\.so(\.[0-9]+)*$' -exec cp '{}' /usr/lib/ ';' # Set up unprivileged user account RUN useradd -ms /bin/bash mediaproc USER mediaproc WORKDIR /home/mediaproc ENV RUST_LOG=trace -CMD ["/usr/bin/mediaproc_server", "0.0.0.0:1500"] +CMD ["/usr/bin/mediaproc_server", 
"0.0.0.0:1500", "/usr/share/dinov2-with-registers-base.pt"] diff --git a/lib/philomena/native.ex b/lib/philomena/native.ex index b4f2f244..2d798fcd 100644 --- a/lib/philomena/native.ex +++ b/lib/philomena/native.ex @@ -12,6 +12,10 @@ defmodule Philomena.Native do @spec camo_image_url(String.t()) :: String.t() def camo_image_url(_uri), do: :erlang.nif_error(:nif_not_loaded) + @spec async_get_features(String.t(), String.t()) :: :ok + def async_get_features(_server_addr, _path), + do: :erlang.nif_error(:nif_not_loaded) + @spec async_process_command(String.t(), String.t(), [String.t()]) :: :ok def async_process_command(_server_addr, _program, _arguments), do: :erlang.nif_error(:nif_not_loaded) diff --git a/lib/philomena_media/remote.ex b/lib/philomena_media/remote.ex index 518a8613..8ff27fc3 100644 --- a/lib/philomena_media/remote.ex +++ b/lib/philomena_media/remote.ex @@ -7,11 +7,23 @@ defmodule PhilomenaMedia.Remote do :ok = Philomena.Native.async_process_command(mediaproc_addr(), command, args) receive do - {:command_reply, command_reply} -> + {:process_command_reply, command_reply} -> {command_reply.stdout, command_reply.status} end end + @doc """ + Gets a feature vector for the given image path to use in reverse image search. 
+ """ + def get_features(path) do + :ok = Philomena.Native.async_get_features(mediaproc_addr(), path) + + receive do + {:get_features_reply, get_features_reply} -> + get_features_reply + end + end + defp mediaproc_addr do Application.get_env(:philomena, :mediaproc_addr) end diff --git a/native/philomena/Cargo.lock b/native/philomena/Cargo.lock index 9b7048c9..ec1ac581 100644 --- a/native/philomena/Cargo.lock +++ b/native/philomena/Cargo.lock @@ -17,6 +17,17 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -128,6 +139,18 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bincode" version = "1.3.3" @@ -137,12 +160,27 @@ dependencies = [ "serde", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" 
+[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bon" version = "3.3.2" @@ -174,18 +212,51 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytemuck" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" + [[package]] name = "byteorder" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + [[package]] name = "bytes" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "caseless" version = "0.2.2" @@ -201,6 +272,8 @@ version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" dependencies = 
[ + "jobserver", + "libc", "shlex", ] @@ -210,6 +283,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.24" @@ -272,6 +355,21 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + +[[package]] +name = "cpufeatures" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -287,6 +385,22 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "darling" version = "0.20.10" @@ -322,6 +436,15 @@ dependencies = [ "syn", ] +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" 
+dependencies = [ + "powerfmt", +] + [[package]] name = "derive_arbitrary" version = "1.4.1" @@ -339,6 +462,17 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "339544cc9e2c4dc3fc7149fd630c5f22263a4fdf18a98afd0075784968b5cf00" +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -433,6 +567,15 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "fdeflate" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" +dependencies = [ + "simd-adler32", +] + [[package]] name = "flate2" version = "1.0.35" @@ -547,6 +690,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -564,6 +717,16 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.15.2" @@ -576,6 +739,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum 
= "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "0.2.12" @@ -738,6 +910,20 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "image" +version = "0.25.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd6f44aed642f18953a158afeb30206f4d50da59fbc66ecb53c66488de73563b" +dependencies = [ + "bytemuck", + "byteorder-lite", + "num-traits", + "png", + "zune-core", + "zune-jpeg", +] + [[package]] name = "indexmap" version = "2.7.0" @@ -748,6 +934,15 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "inventory" version = "0.3.17" @@ -789,6 +984,15 @@ dependencies = [ "libc", ] +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.76" @@ -855,6 +1059,16 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "mediaproc" version = "0.1.0" @@ -881,8 +1095,10 @@ dependencies = [ "clap", "env_logger", "futures", + "image", 
"mediaproc", "tarpc", + "tch", "tempfile", "tokio", "tracing", @@ -901,6 +1117,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -914,6 +1131,52 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "object" version = "0.36.7" @@ -989,6 +1252,29 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "password-hash" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + +[[package]] +name = "pbkdf2" +version = "0.11.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +dependencies = [ + "digest", + "hmac", + "password-hash", + "sha2", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -999,18 +1285,18 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" name = "philomena" version = "0.3.0" dependencies = [ - "base64", + "base64 0.21.7", "comrak", "http", "jemallocator", "mediaproc", "once_cell", "regex", - "ring", + "ring 0.16.20", "rustler", "tokio", "url", - "zip", + "zip 2.2.2", ] [[package]] @@ -1045,6 +1331,31 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + +[[package]] +name = "png" +version = "0.17.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526" +dependencies = [ + "bitflags 1.3.2", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -1112,13 +1423,19 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "redox_syscall" version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" 
dependencies = [ - "bitflags", + "bitflags 2.6.0", ] [[package]] @@ -1165,12 +1482,27 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", - "untrusted", + "spin 0.5.2", + "untrusted 0.7.1", "web-sys", "winapi", ] +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys 0.52.0", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1183,7 +1515,7 @@ version = "0.38.43" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" dependencies = [ - "bitflags", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -1215,6 +1547,38 @@ dependencies = [ "syn", ] +[[package]] +name = "rustls" +version = "0.23.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" +dependencies = [ + "log", + "once_cell", + "ring 0.17.8", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" + +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "ring 0.17.8", + "rustls-pki-types", + "untrusted 0.9.0", +] + [[package]] name = "rustversion" version = "1.0.19" @@ -1227,6 +1591,16 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = 
"safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -1265,6 +1639,28 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1336,6 +1732,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -1354,6 +1756,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.95" @@ -1412,6 +1820,23 @@ dependencies = [ "syn", ] +[[package]] +name = "tch" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb3500c87ef72447c23b33ed6f15fac45a616b09bcac53e62e0e4386bddb3b9d" 
+dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand", + "safetensors", + "thiserror 1.0.69", + "torch-sys", + "zip 0.6.6", +] + [[package]] name = "tempfile" version = "3.15.0" @@ -1476,6 +1901,25 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + [[package]] name = "tinystr" version = "0.7.6" @@ -1560,6 +2004,21 @@ dependencies = [ "tokio", ] +[[package]] +name = "torch-sys" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61b87ed41261d4278060c3ba3e735c224687cf312403e4565f2ca75310279d73" +dependencies = [ + "anyhow", + "cc", + "libc", + "serde", + "serde_json", + "ureq", + "zip 0.6.6", +] + [[package]] name = "tracing" version = "0.1.41" @@ -1626,6 +2085,12 @@ version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-ident" version = "1.0.14" @@ -1653,6 +2118,30 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "url", + "webpki-roots", +] + [[package]] name = "url" version = "2.5.4" @@ -1688,6 +2177,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1768,6 +2263,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1950,6 +2454,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + [[package]] name = "zerovec" version = "0.10.4" @@ -1972,6 +2482,26 @@ dependencies = [ "syn", ] +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd", +] + 
[[package]] name = "zip" version = "2.2.2" @@ -2002,3 +2532,47 @@ dependencies = [ "once_cell", "simd-adler32", ] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] + +[[package]] +name = "zune-core" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a" + +[[package]] +name = "zune-jpeg" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99a5bab8d7dedf81405c4bb1f2b83ea057643d9cb28778cea9eecddeedd2e028" +dependencies = [ + "zune-core", +] diff --git a/native/philomena/mediaproc/src/client.rs b/native/philomena/mediaproc/src/client.rs index f76ef6d5..90d86073 100644 --- a/native/philomena/mediaproc/src/client.rs +++ b/native/philomena/mediaproc/src/client.rs @@ -98,25 +98,29 @@ fn update_replacements( Ok(()) } -fn context_with_1_hour_deadline() -> Context { +pub fn context_with_deadline(secs_from_now: u64) -> Context { let mut context = Context::current(); - context.deadline = Instant::now() + Duration::from_secs(60 * 60); + context.deadline = Instant::now() + Duration::from_secs(secs_from_now); context } +pub fn context_with_1_hour_deadline() -> Context { + context_with_deadline(60 * 60) +} + +pub fn 
context_with_10_second_deadline() -> Context { + context_with_deadline(10) +} + pub async fn execute_command( client: &MediaProcessorClient, program: String, arguments: Vec<String>, + ctx: Context, ) -> Result<CommandReply, ExecuteCommandError> { let call_params = create_replacements(arguments.into_iter()); let (reply, file_map) = client - .execute_command( - context_with_1_hour_deadline(), - program, - call_params.arguments, - call_params.file_map, - ) + .execute_command(ctx, program, call_params.arguments, call_params.file_map) .await .map_err(|_| ExecuteCommandError::UnknownError)??; diff --git a/native/philomena/mediaproc/src/lib.rs b/native/philomena/mediaproc/src/lib.rs index 5dedfdbf..10faf7c4 100644 --- a/native/philomena/mediaproc/src/lib.rs +++ b/native/philomena/mediaproc/src/lib.rs @@ -12,11 +12,16 @@ pub trait MediaProcessor { arguments: Vec<String>, file_map: FileMap, ) -> Result<(CommandReply, FileMap), ExecuteCommandError>; + + /// Runs feature extraction on the bytes of an image file (PNG or JPEG). + async fn get_features(image: Vec<u8>) -> Result<Vec<f32>, FeatureExtractionError>; } /// Errors which can occur during command execution. #[derive(Debug, Deserialize, Serialize)] pub enum ExecuteCommandError { + /// Failed to connect to server. + ConnectionError, /// Requested program was not allowed to be executed. UnpermittedProgram(String), /// Failed to launch program. @@ -31,6 +36,19 @@ pub enum ExecuteCommandError { UnknownError, } +/// Errors which can occur during image feature extraction. +#[derive(Debug, Deserialize, Serialize)] +pub enum FeatureExtractionError { + /// Failed to connect to server. + ConnectionError, + /// Generic filesystem error. + LocalFilesystemError, + /// Unrecognized image format. + UnknownImageFormat, + /// Failed to decode the image. + ImageDecodeError, +} + /// Enumeration of permitted program names. 
pub static PERMITTED_PROGRAMS: Lazy<HashSet<&'static str>> = Lazy::new(|| { vec![ diff --git a/native/philomena/mediaproc_client/src/main.rs b/native/philomena/mediaproc_client/src/main.rs index f9e70502..22949a41 100644 --- a/native/philomena/mediaproc_client/src/main.rs +++ b/native/philomena/mediaproc_client/src/main.rs @@ -2,7 +2,7 @@ use std::io::Write; use std::process::ExitCode; use clap::{Parser, Subcommand}; -use mediaproc::client::{connect_to_socket_server, execute_command}; +use mediaproc::client; use mediaproc::MediaProcessorClient; #[derive(Parser, Debug)] @@ -28,12 +28,17 @@ enum InvocationType { /// Arguments to pass to program. args: Vec<String>, }, + /// Get DINOv2 features from the given image file (PNG or JPEG). + ExtractFeatures { + /// Filename to extract from. + file_name: String, + }, } #[tokio::main(flavor = "current_thread")] async fn main() -> ExitCode { let args = Arguments::parse(); - let client = connect_to_socket_server(&args.server_addr) + let client = client::connect_to_socket_server(&args.server_addr) .await .expect("failed to connect to server"); @@ -41,6 +46,9 @@ async fn main() -> ExitCode { InvocationType::ExecuteCommand { program, args } => { run_command_client(&client, program, args).await } + InvocationType::ExtractFeatures { file_name } => { + run_feature_extraction_client(&client, file_name).await + } } } @@ -49,7 +57,10 @@ async fn run_command_client( program: String, args: Vec<String>, ) -> ExitCode { - let reply = execute_command(client, program, args).await.unwrap(); + let ctx = client::context_with_1_hour_deadline(); + let reply = client::execute_command(client, program, args, ctx) + .await + .unwrap(); write_then_drop(std::io::stderr(), reply.stderr); write_then_drop(std::io::stdout(), reply.stdout); @@ -60,3 +71,29 @@ async fn run_command_client( fn write_then_drop(mut stream: impl Write, data: Vec<u8>) { stream.write_all(&data).unwrap() } + +async fn run_feature_extraction_client( + client: &MediaProcessorClient, + 
file_name: String, +) -> ExitCode { + let image = std::fs::read(file_name).unwrap(); + let features = client + .get_features(client::context_with_10_second_deadline(), image) + .await + .unwrap() + .unwrap(); + + // Manual intersperse implementation, until rust adds it properly + let mut started = false; + for component in features { + if started { + print!(" {}", component); + } else { + print!("{}", component); + started = true; + } + } + println!(); + + ExitCode::SUCCESS +} diff --git a/native/philomena/mediaproc_server/Cargo.toml b/native/philomena/mediaproc_server/Cargo.toml index 834dcd81..38df0a7a 100644 --- a/native/philomena/mediaproc_server/Cargo.toml +++ b/native/philomena/mediaproc_server/Cargo.toml @@ -7,8 +7,10 @@ edition = "2021" env_logger = "0.11" clap = { version = "4.5", features = ["derive"] } futures = "0.3" +image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } mediaproc = { path = "../mediaproc" } tarpc = { version = "0.35", features = ["full"] } +tch = { version = "0.18.1", features = ["download-libtorch"] } tempfile = "3" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" diff --git a/native/philomena/mediaproc_server/src/dinov2.rs b/native/philomena/mediaproc_server/src/dinov2.rs new file mode 100644 index 00000000..d6c65b9a --- /dev/null +++ b/native/philomena/mediaproc_server/src/dinov2.rs @@ -0,0 +1,106 @@ +use std::io::Cursor; +use tch::{CModule, Device, IValue, Tensor}; + +use super::io; +use crate::FeatureExtractionError; + +/// Each DINOv2 patch is 14x14 +pub const PATCH_DIM: i64 = 14; + +pub struct ModelResult { + pub patches: (i64, i64), + pub image: Tensor, + pub features: Tensor, + pub last_hidden_state: Tensor, +} + +fn infer(image: &Tensor, model: &CModule) -> (Tensor, Tensor) { + // These cases intentionally panic because their outputs depend on the model, + // not on the image file input, and invalid model format is not recoverable. 
+ let output = model + .forward_is(&[IValue::Tensor(image.shallow_clone())]) + .unwrap(); + + let mut results = match output { + IValue::Tuple(elements) if elements.len() == 2 => elements, + _ => unreachable!("expected (last_hidden_state, pooler_output)"), + }; + + let mut results = results.drain(..); + + match (results.next(), results.next()) { + (Some(IValue::Tensor(last_hidden_state)), Some(IValue::Tensor(pooler_output))) => { + (last_hidden_state, pooler_output) + } + _ => unreachable!("expected 2-tuple of tensors"), + } +} + +fn scaled_result(pooler_output: &Tensor) -> Tensor { + let scaled_norm = pooler_output.norm().pow_tensor_scalar(-1); + pooler_output.multiply(&scaled_norm) +} + +pub fn get_model_result<R>( + image: R, + model: &CModule, + device: Device, +) -> Result<ModelResult, FeatureExtractionError> +where + R: std::io::Read + std::io::Seek, +{ + // Get image and dimensions for calculation. + let image = io::load_image(image, device)?; + + // Features are unstable across different global scales, and + // somewhat stable across dimensional scales. + // + // Use 18 (252x252) instead of 16 (224x224) to produce a more detailed + // result and attention map at almost exactly the same computational cost. + // + // It is possible for highly non-square models to have meaningful feature extraction, + // but in practice it makes no difference identifying scales which keep the aspect + // ratio, and does not produce feature vectors which are similar enough to identify + // crops. + let image_scale = 1; + let patches = (18 * image_scale, 18 * image_scale); + + // Scale image into appropriate shape. + let image = io::resize_image_by_patch_count(image, patches, PATCH_DIM); + + // The pooler output is the [CLS] token generated by the model. + // It contains high-quality, robust features from the input image. 
+ let (last_hidden_state, pooler_output) = infer(&image, model); + + Ok(ModelResult { + patches, + image, + features: scaled_result(&pooler_output.squeeze()), + last_hidden_state, + }) +} + +pub struct Executor { + device: Device, + model: CModule, +} + +impl Executor { + pub fn new(model_path: &str) -> Option<Self> { + let (device, model) = io::device_and_model(model_path)?; + Some(Self { device, model }) + } + + pub fn extract(&self, image: &[u8]) -> Result<Vec<f32>, FeatureExtractionError> { + let image = Cursor::new(image); + let model_result = get_model_result(image, &self.model, self.device)?; + let features = model_result + .features + .iter::<f64>() + .unwrap() + .map(|f| f as f32) + .collect(); + + Ok(features) + } +} diff --git a/native/philomena/mediaproc_server/src/io.rs b/native/philomena/mediaproc_server/src/io.rs new file mode 100644 index 00000000..43f8618b --- /dev/null +++ b/native/philomena/mediaproc_server/src/io.rs @@ -0,0 +1,100 @@ +use image::{DynamicImage, ImageBuffer, ImageReader, Pixel}; +use std::io::BufReader; +use tch::{CModule, Device, Tensor}; + +use crate::FeatureExtractionError; + +pub fn device_and_model(model_path: &str) -> Option<(Device, CModule)> { + let device = Device::cuda_if_available(); + let model = CModule::load_on_device(model_path, device).ok()?; + + Some((device, model)) +} + +fn into_tensor<P: Pixel<Subpixel = f32>>( + image: ImageBuffer<P, Vec<f32>>, + device: Device, +) -> Tensor { + let w: i64 = image.width().into(); + let h: i64 = image.height().into(); + let c: i64 = P::CHANNEL_COUNT.into(); + + // Extra scope to ensure we eagerly drop the original image buffer + let pixels = { + let pixels: Vec<f32> = image.pixels().flat_map(|p| p.channels()).copied().collect(); + + Tensor::from_slice(&pixels) + }; + + pixels.to(device).reshape([h, w, c]).permute([2, 0, 1]) +} + +fn strip_transparency(image: DynamicImage, device: Device) -> Tensor { + let w: i64 = image.width().into(); + let h: i64 = image.height().into(); + + 
match image { + DynamicImage::ImageRgb8(..) + | DynamicImage::ImageLuma8(..) + | DynamicImage::ImageLuma16(..) + | DynamicImage::ImageRgb16(..) + | DynamicImage::ImageRgb32F(..) => { + return into_tensor(image.into_rgb32f(), device); + } + _ => {} + }; + + // Get channels. + let (alpha, color) = { + let pixels = into_tensor(image.into_rgba32f(), device); + let alpha = pixels.slice(0, 3, 4, 1).broadcast_to([3, h, w]); + let color = pixels.slice(0, 0, 3, 1); + + (alpha, color) + }; + + // Detect whether premultiplication should be applied by checking + // for channels with values above the alpha level. + // + // Note that the only input format which we can get where this would + // be relevant, PNG, explicitly says it does not carry premultiplied alpha, + // but many tools will store premultiplied alpha anyway... + let ones = Tensor::ones([3, h, w], (tch::Kind::Float, device)); + let mask = alpha.where_self(&color.gt_tensor(&alpha).any(), &ones); + let color = color.multiply(&mask); + + // Pure transparency is rescaled to be 8 steps blacker than black. + const ALPHA_LEVEL: f64 = 8.0 / 255.0; + const COLOR_LEVEL: f64 = 1.0 - ALPHA_LEVEL; + + // Unwrap is guaranteed safe because the dimensions and data type match + color + .multiply_scalar(COLOR_LEVEL) + .f_add(&alpha.multiply_scalar(ALPHA_LEVEL)) + .unwrap() +} + +pub fn load_image<R>(image: R, device: Device) -> Result<Tensor, FeatureExtractionError> +where + R: std::io::Read + std::io::Seek, +{ + let image = BufReader::new(image); + let image = ImageReader::new(image) + .with_guessed_format() + .map_err(|_| FeatureExtractionError::UnknownImageFormat)? 
+ .decode() + .map_err(|_| FeatureExtractionError::ImageDecodeError)?; + + Ok(strip_transparency(image, device)) +} + +fn resize_tensor(image: Tensor, size: (i64, i64)) -> Tensor { + image.upsample_bicubic2d([size.0, size.1], true, None, None) +} + +pub fn resize_image_by_patch_count(image: Tensor, patches: (i64, i64), patch_dim: i64) -> Tensor { + let height = patches.0 * patch_dim; + let width = patches.1 * patch_dim; + + resize_tensor(image.unsqueeze(0), (height, width)) +} diff --git a/native/philomena/mediaproc_server/src/main.rs b/native/philomena/mediaproc_server/src/main.rs index 05683c08..0466fac9 100644 --- a/native/philomena/mediaproc_server/src/main.rs +++ b/native/philomena/mediaproc_server/src/main.rs @@ -1,12 +1,18 @@ use std::net::SocketAddr; +use std::sync::Arc; use clap::Parser; +use dinov2::Executor; use futures::{future, Future, StreamExt}; -use mediaproc::{CommandReply, ExecuteCommandError, FileMap, MediaProcessor}; +use mediaproc::{ + CommandReply, ExecuteCommandError, FeatureExtractionError, FileMap, MediaProcessor, +}; use tarpc::context; use tarpc::server::Channel; mod command_server; +mod dinov2; +mod io; mod signal; #[derive(Parser, Debug)] @@ -14,10 +20,13 @@ mod signal; struct Arguments { /// Socket address to bind to, like 127.0.0.1:1500 server_addr: SocketAddr, + + /// DINOv2 with registers base model to load. 
+ model_path: String, } #[derive(Clone)] -struct MediaProcessorServer; +struct MediaProcessorServer(Arc<Executor>); impl MediaProcessor for MediaProcessorServer { async fn execute_command( @@ -29,14 +38,24 @@ impl MediaProcessor for MediaProcessorServer { ) -> Result<(CommandReply, FileMap), ExecuteCommandError> { command_server::execute_command(program, arguments, file_map).await } + + async fn get_features( + self, + _: context::Context, + image: Vec<u8>, + ) -> Result<Vec<f32>, FeatureExtractionError> { + self.0.extract(&image) + } } fn main() { env_logger::init(); let args = Arguments::parse(); + let executor = Executor::new(&args.model_path).expect("failed to load Torch JIT model"); + let executor = Arc::new(executor); - serve(&args); + serve(&args, executor); } async fn spawn(fut: impl Future<Output = ()> + Send + 'static) { @@ -44,7 +63,7 @@ async fn spawn(fut: impl Future<Output = ()> + Send + 'static) { } #[tokio::main] -async fn serve(args: &Arguments) { +async fn serve(args: &Arguments, executor: Arc<Executor>) { signal::install_handlers(); let codec = tarpc::tokio_serde::formats::Bincode::default; @@ -57,12 +76,10 @@ async fn serve(args: &Arguments) { // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) .map(tarpc::server::BaseChannel::with_defaults) - .map(|channel| { - tokio::spawn( - channel - .execute(MediaProcessorServer.serve()) - .for_each(spawn), - ); + .map(move |channel| { + let server = MediaProcessorServer(executor.clone()); + + tokio::spawn(channel.execute(server.serve()).for_each(spawn)); }) .collect() .await diff --git a/native/philomena/src/lib.rs b/native/philomena/src/lib.rs index f2f5893e..0de54fe5 100644 --- a/native/philomena/src/lib.rs +++ b/native/philomena/src/lib.rs @@ -39,6 +39,12 @@ fn camo_image_url(input: &str) -> String { // Remote NIF wrappers. 
+#[rustler::nif] +fn async_get_features(env: Env, server_addr: String, path: String) -> Atom { + let fut = remote::get_features(server_addr, path); + asyncnif::call_async(env, fut, remote::get_features_reply_with_env) +} + #[rustler::nif] fn async_process_command( env: Env, @@ -47,7 +53,7 @@ fn async_process_command( arguments: Vec<String>, ) -> Atom { let fut = remote::process_command(server_addr, program, arguments); - asyncnif::call_async(env, fut, remote::with_env) + asyncnif::call_async(env, fut, remote::command_reply_with_env) } // Zip NIF wrappers. diff --git a/native/philomena/src/remote.rs b/native/philomena/src/remote.rs index e0cd4447..1d5533db 100644 --- a/native/philomena/src/remote.rs +++ b/native/philomena/src/remote.rs @@ -1,10 +1,13 @@ -use mediaproc::client::{connect_to_socket_server, execute_command}; -use mediaproc::CommandReply; +use mediaproc::client; +use mediaproc::{CommandReply, FeatureExtractionError}; use rustler::{atoms, Encoder, Env, NifStruct, OwnedBinary, Term}; atoms! { nil, - command_reply, + ok, + error, + get_features_reply, + process_command_reply, } #[derive(NifStruct)] @@ -30,7 +33,7 @@ pub async fn process_command( program: String, arguments: Vec<String>, ) -> CommandReply { - let client = match connect_to_socket_server(&server_addr).await { + let client = match client::connect_to_socket_server(&server_addr).await { Some(client) => client, None => { return CommandReply { @@ -41,7 +44,8 @@ pub async fn process_command( } }; - match execute_command(&client, program, arguments).await { + let ctx = client::context_with_1_hour_deadline(); + match client::execute_command(&client, program, arguments, ctx).await { Ok(reply) => reply, Err(err) => CommandReply { stdout: vec![], @@ -51,11 +55,29 @@ pub async fn process_command( } } -/// Converts the response into a {:command_reply, %CommandReply{...}} message -/// which gets sent back to the caller. 
-pub fn with_env<'a>(env: Env<'a>, r: CommandReply) -> Term<'a> { +pub async fn get_features( + server_addr: String, + path: String, +) -> Result<Vec<f32>, FeatureExtractionError> { + let client = match client::connect_to_socket_server(&server_addr).await { + Some(client) => client, + None => return Err(FeatureExtractionError::ConnectionError), + }; + + let image = std::fs::read(path).map_err(|_| FeatureExtractionError::LocalFilesystemError)?; + let ctx = client::context_with_10_second_deadline(); + + client + .get_features(ctx, image) + .await + .map_err(|_| FeatureExtractionError::ConnectionError)? +} + +/// Converts the response into a {:process_command_reply, %CommandReply{...}} +/// message which gets sent back to the caller. +pub fn command_reply_with_env<'a>(env: Env<'a>, r: CommandReply) -> Term<'a> { ( - command_reply(), + process_command_reply(), CommandReply_ { stdout: binary_or_nil(env, r.stdout), stderr: binary_or_nil(env, r.stderr), @@ -64,3 +86,15 @@ pub fn with_env<'a>(env: Env<'a>, r: CommandReply) -> Term<'a> { ) .encode(env) } + +/// Converts the response into a {:get_features_reply, {:ok, [0.1, ..., 0.1]}} +/// message which gets sent back to the caller. +pub fn get_features_reply_with_env<'a>( + env: Env<'a>, + r: Result<Vec<f32>, FeatureExtractionError>, +) -> Term<'a> { + match r { + Ok(features) => (get_features_reply(), (ok(), features)).encode(env), + Err(e) => (get_features_reply(), (error(), format!("{e:?}"))).encode(env), + } +}