new model ids 2 #929

Draft. Wants to merge 6 commits into base: main.
9 changes: 9 additions & 0 deletions inference/core/managers/active_learning.py
@@ -39,10 +39,13 @@ async def infer_from_request(
         active_learning_disabled_for_request = getattr(
             request, DISABLE_ACTIVE_LEARNING_PARAM, False
         )
+        # TODO: active learning is disabled for instant models; to be enabled in the future
+        roboflow_instant_model = len(str(model_id).split("/")) == 1
         if (
             not active_learning_eligible
             or active_learning_disabled_for_request
             or request.api_key is None
+            or roboflow_instant_model
         ):
             return prediction
         self.register(prediction=prediction, model_id=model_id, request=request)
@@ -58,10 +61,13 @@ def infer_from_request_sync(
         active_learning_disabled_for_request = getattr(
             request, DISABLE_ACTIVE_LEARNING_PARAM, False
         )
+        # TODO: active learning is disabled for instant models; to be enabled in the future
+        roboflow_instant_model = len(str(model_id).split("/")) == 1
         if (
             not active_learning_eligible
             or active_learning_disabled_for_request
             or request.api_key is None
+            or roboflow_instant_model
         ):
             return prediction
         self.register(prediction=prediction, model_id=model_id, request=request)
@@ -196,10 +202,13 @@ def infer_from_request_sync(
         prediction = super().infer_from_request_sync(
             model_id=model_id, request=request, **kwargs
         )
+        # TODO: active learning is disabled for instant models; to be enabled in the future
+        roboflow_instant_model = len(str(model_id).split("/")) == 1
         if (
             not active_learning_eligible
             or active_learning_disabled_for_request
             or request.api_key is None
+            or roboflow_instant_model
         ):
             return prediction
         if BACKGROUND_TASKS_PARAM not in kwargs:
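The gate added in all three code paths hinges on a single heuristic: an instant model ID has no "/"-separated version suffix. A minimal standalone sketch of that classification (the helper name is hypothetical; only the split-and-count check comes from this diff):

def is_instant_model(model_id: str) -> bool:
    # Versioned IDs look like "dataset/3"; instant IDs are bare, e.g. "my-model".
    return len(str(model_id).split("/")) == 1

assert is_instant_model("my-model")
assert not is_instant_model("my-dataset/3")

Any bare ID therefore skips active-learning registration, in line with the TODO notes above.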
7 changes: 6 additions & 1 deletion inference/core/models/roboflow.py
@@ -116,7 +116,12 @@ def __init__(
         self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
         self.api_key = api_key if api_key else API_KEY
         model_id = resolve_roboflow_model_alias(model_id=model_id)
-        self.dataset_id, self.version_id = model_id.split("/")
+        model_id_chunks = model_id.split("/")
+        if len(model_id_chunks) == 1:
+            self.dataset_id = model_id
+            self.version_id = None
+        else:
+            self.dataset_id, self.version_id = model_id.split("/")
         self.endpoint = model_id
         self.device_id = GLOBAL_DEVICE_ID
         self.cache_dir = os.path.join(cache_dir_root, self.endpoint)
7 changes: 6 additions & 1 deletion inference/core/models/stubs.py
@@ -17,7 +17,12 @@ def __init__(self, model_id: str, api_key: str):
         super().__init__()
         self.model_id = model_id
         self.api_key = api_key
-        self.dataset_id, self.version_id = model_id.split("/")
+        model_id_chunks = model_id.split("/")
+        if len(model_id_chunks) == 1:
+            self.dataset_id = model_id
+            self.version_id = None
+        else:
+            self.dataset_id, self.version_id = model_id.split("/")
         self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
         initialise_cache(model_id=model_id)
 
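This constructor change mirrors the one in inference/core/models/roboflow.py: a bare model ID becomes the dataset ID and version_id is set to None. A standalone sketch of the shared parsing logic (the function name is illustrative):

from typing import Optional, Tuple

def parse_model_id(model_id: str) -> Tuple[str, Optional[str]]:
    # Mirrors the constructor logic: bare IDs carry no version.
    chunks = model_id.split("/")
    if len(chunks) == 1:
        return model_id, None
    dataset_id, version_id = model_id.split("/")  # ValueError if more than two chunks
    return dataset_id, version_id

assert parse_model_id("my-instant-model") == ("my-instant-model", None)
assert parse_model_id("my-dataset/5") == ("my-dataset", "5")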
38 changes: 26 additions & 12 deletions inference/core/registries/roboflow.py
@@ -19,6 +19,7 @@
     PROJECT_TASK_TYPE_KEY,
     ModelEndpointType,
     get_roboflow_dataset_type,
+    get_roboflow_instant_model_data,
     get_roboflow_model_data,
     get_roboflow_workspace,
 )
@@ -115,12 +116,21 @@ def get_model_type(
             model_type=model_type,
         )
         return project_task_type, model_type
-    api_data = get_roboflow_model_data(
-        api_key=api_key,
-        model_id=model_id,
-        endpoint_type=ModelEndpointType.ORT,
-        device_id=GLOBAL_DEVICE_ID,
-    ).get("ort")
+
+    if version_id is None:
+        api_data = get_roboflow_instant_model_data(
+            api_key=api_key,
+            model_id=model_id,
+            endpoint_type=ModelEndpointType.ORT,
+            device_id=GLOBAL_DEVICE_ID,
+        ).get("ort")
+    else:
+        api_data = get_roboflow_model_data(
+            api_key=api_key,
+            model_id=model_id,
+            endpoint_type=ModelEndpointType.ORT,
+            device_id=GLOBAL_DEVICE_ID,
+        ).get("ort")
     if api_data is None:
         raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
     # some older projects do not have type field - hence defaulting
@@ -143,7 +153,7 @@ def get_model_type(
 
 
 def get_model_metadata_from_cache(
-    dataset_id: str, version_id: str
+    dataset_id: str, version_id: Optional[VersionID]
 ) -> Optional[Tuple[TaskType, ModelType]]:
     if LAMBDA:
         return _get_model_metadata_from_cache(
@@ -158,7 +168,7 @@ def get_model_metadata_from_cache(
 
 
 def _get_model_metadata_from_cache(
-    dataset_id: str, version_id: str
+    dataset_id: str, version_id: Optional[VersionID]
 ) -> Optional[Tuple[TaskType, ModelType]]:
     model_type_cache_path = construct_model_type_cache_path(
         dataset_id=dataset_id, version_id=version_id
@@ -194,7 +204,7 @@ def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool:
 
 def save_model_metadata_in_cache(
     dataset_id: DatasetID,
-    version_id: VersionID,
+    version_id: Optional[VersionID],
     project_task_type: TaskType,
     model_type: ModelType,
 ) -> None:
@@ -220,7 +230,7 @@ def save_model_metadata_in_cache(
 
 def _save_model_metadata_in_cache(
     dataset_id: DatasetID,
-    version_id: VersionID,
+    version_id: Optional[VersionID],
     project_task_type: TaskType,
     model_type: ModelType,
 ) -> None:
@@ -236,6 +246,10 @@ def _save_model_metadata_in_cache(
     )
 
 
-def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str:
-    cache_dir = os.path.join(MODEL_CACHE_DIR, dataset_id, version_id)
+def construct_model_type_cache_path(
+    dataset_id: str, version_id: Optional[VersionID]
+) -> str:
+    cache_dir = os.path.join(
+        MODEL_CACHE_DIR, dataset_id, version_id if version_id is not None else ""
+    )
     return os.path.join(cache_dir, "model_type.json")
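With version_id equal to None the cache path collapses to the dataset directory: os.path.join only appends a trailing separator for the empty component, and the final join absorbs it. A quick standalone check of the reconstructed behavior (the MODEL_CACHE_DIR value is assumed for illustration; the real one comes from the package's env config):

import os
from typing import Optional

MODEL_CACHE_DIR = "/tmp/model-cache"  # assumed value for illustration

def construct_model_type_cache_path(dataset_id: str, version_id: Optional[str]) -> str:
    cache_dir = os.path.join(
        MODEL_CACHE_DIR, dataset_id, version_id if version_id is not None else ""
    )
    return os.path.join(cache_dir, "model_type.json")

print(construct_model_type_cache_path("my-dataset", "3"))
# /tmp/model-cache/my-dataset/3/model_type.json
print(construct_model_type_cache_path("my-instant-model", None))
# /tmp/model-cache/my-instant-model/model_type.json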
38 changes: 38 additions & 0 deletions inference/core/roboflow_api.py
@@ -244,6 +244,44 @@ def get_roboflow_model_data(
     return api_data
 
 
+@wrap_roboflow_api_errors()
+def get_roboflow_instant_model_data(
+    api_key: str,
+    model_id: str,
+    endpoint_type: ModelEndpointType,
+    device_id: str,
+) -> dict:
+    api_data_cache_key = f"roboflow_api_data:{endpoint_type.value}:{model_id}"
+    api_data = cache.get(api_data_cache_key)
+    if api_data is not None:
+        logger.debug(f"Loaded model data from cache with key: {api_data_cache_key}.")
+        return api_data
+    else:
+        params = [
+            ("nocache", "true"),
+            ("device", device_id),
+            ("dynamic", "true"),
+            ("type", endpoint_type.value),
+            ("model", model_id),
+        ]
+        if api_key is not None:
+            params.append(("api_key", api_key))
+        api_url = _add_params_to_url(
+            url=f"{API_BASE_URL}/getWeights",
+            params=params,
+        )
+        api_data = _get_from_url(url=api_url)
+        cache.set(
+            api_data_cache_key,
+            api_data,
+            expire=10,
+        )
+        logger.debug(
+            f"Loaded model data from Roboflow API and saved to cache with key: {api_data_cache_key}."
+        )
+        return api_data
+
+
 @wrap_roboflow_api_errors()
 def get_roboflow_base_lora(
     api_key: str, repo: str, revision: str, device_id: str
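get_roboflow_instant_model_data follows a cache-aside pattern with a short TTL (expire=10), so a burst of requests for the same instant model hits the Roboflow API roughly once every ten seconds. A generic sketch of that pattern (the in-memory store and names are illustrative, not the project's cache object):

import time
from typing import Any, Callable, Dict, Tuple

_store: Dict[str, Tuple[float, Any]] = {}

def cached_fetch(key: str, fetch: Callable[[], Any], ttl: float = 10.0) -> Any:
    # Serve the cached value while it is younger than ttl seconds, else refetch.
    entry = _store.get(key)
    if entry is not None and time.monotonic() - entry[0] < ttl:
        return entry[1]
    value = fetch()
    _store[key] = (time.monotonic(), value)
    return value

data = cached_fetch(
    "roboflow_api_data:ort:my-instant-model",
    fetch=lambda: {"ort": {"type": "object-detection"}},  # stand-in for the HTTP call
)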
2 changes: 2 additions & 0 deletions inference/core/utils/roboflow.py
@@ -6,6 +6,8 @@
 
 def get_model_id_chunks(model_id: str) -> Tuple[DatasetID, VersionID]:
     model_id_chunks = model_id.split("/")
+    if len(model_id_chunks) == 1:
+        return model_id, None
     if len(model_id_chunks) != 2:
         raise InvalidModelIDError(f"Model ID: `{model_id}` is invalid.")
     return model_id_chunks[0], model_id_chunks[1]
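After this change get_model_id_chunks accepts bare IDs instead of rejecting them, while IDs with more than one "/" still raise. A usage sketch (assuming the inference package is importable):

from inference.core.utils.roboflow import get_model_id_chunks

print(get_model_id_chunks("my-dataset/3"))      # ('my-dataset', '3')
print(get_model_id_chunks("my-instant-model"))  # ('my-instant-model', None)
get_model_id_chunks("a/b/c")                    # raises InvalidModelIDError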
21 changes: 15 additions & 6 deletions inference/models/owlv2/owlv2.py
@@ -37,6 +37,7 @@
 from inference.core.roboflow_api import (
     ModelEndpointType,
     get_from_url,
+    get_roboflow_instant_model_data,
     get_roboflow_model_data,
 )
 from inference.core.utils.image_utils import (
@@ -736,12 +737,20 @@ def download_model_artefacts_from_s3(self):
         raise NotImplementedError("Owlv2 not currently supported on hosted inference")
 
     def download_model_artifacts_from_roboflow_api(self):
-        api_data = get_roboflow_model_data(
-            api_key=self.api_key,
-            model_id=self.endpoint,
-            endpoint_type=ModelEndpointType.OWLV2,
-            device_id=self.device_id,
-        )
+        if self.version_id is None:
+            api_data = get_roboflow_instant_model_data(
+                api_key=self.api_key,
+                model_id=self.endpoint,
+                endpoint_type=ModelEndpointType.OWLV2,
+                device_id=self.device_id,
+            )
+        else:
+            api_data = get_roboflow_model_data(
+                api_key=self.api_key,
+                model_id=self.endpoint,
+                endpoint_type=ModelEndpointType.OWLV2,
+                device_id=self.device_id,
+            )
         api_data = api_data["owlv2"]
         if "model" not in api_data:
             raise ModelArtefactError(
4 changes: 3 additions & 1 deletion inference/usage_tracking/collector.py
@@ -574,7 +574,9 @@ def _extract_usage_params_from_func_kwargs(
         elif "self" in func_kwargs:
             _self = func_kwargs["self"]
             if hasattr(_self, "dataset_id") and hasattr(_self, "version_id"):
-                model_id = f"{_self.dataset_id}/{_self.version_id}"
+                model_id = str(_self.dataset_id)
+                if _self.version_id:
+                    model_id += f"/{_self.version_id}"
                 category = "model"
                 resource_id = model_id
             elif isinstance(kwargs, dict) and "model_id" in kwargs:
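Usage tracking now rebuilds the resource ID from its parts, so an instant model reports its bare dataset_id rather than a "dataset/None" string. A standalone sketch of the new behavior (the Model dataclass is a hypothetical stand-in for the inference model object):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Model:  # hypothetical stand-in
    dataset_id: str
    version_id: Optional[str]

def usage_model_id(m: Model) -> str:
    model_id = str(m.dataset_id)
    if m.version_id:
        model_id += f"/{m.version_id}"
    return model_id

assert usage_model_id(Model("my-dataset", "3")) == "my-dataset/3"
assert usage_model_id(Model("my-instant-model", None)) == "my-instant-model"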
11 changes: 10 additions & 1 deletion tests/inference/unit_tests/core/utils/test_roboflow.py
@@ -6,13 +6,22 @@
 from inference.core.utils.roboflow import get_model_id_chunks
 
 
-@pytest.mark.parametrize("value", ["some", "some/2/invalid", "another-2"])
+@pytest.mark.parametrize("value", ["some/2/invalid"])
 def test_get_model_id_chunks_when_invalid_input_given(value: Any) -> None:
     # when
     with pytest.raises(InvalidModelIDError):
         _ = get_model_id_chunks(model_id=value)
 
 
+@pytest.mark.parametrize("value", ["some", "another-2"])
+def test_get_model_id_chunks_when_instant_model_id_given(value: Any) -> None:
+    # when
+    result = get_model_id_chunks(model_id=value)
+
+    # then
+    assert result == (value, None)
+
+
 def test_get_model_id_chunks_when_valid_input_given() -> None:
     # when
     result = get_model_id_chunks("some/1")