Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions label_studio_sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
""" .. include::../docs/client.md
"""
import requests
import warnings
import logging
import requests

from typing import Optional
from pydantic import BaseModel, constr, root_validator
from requests.adapters import HTTPAdapter

logger = logging.getLogger(__name__)
Expand All @@ -12,9 +15,21 @@
HEADERS = {}


class ClientCredentials(BaseModel):
email: Optional[str]
password: Optional[str]
api_key: Optional[constr(regex=r"^[a-zA-Z0-9]{40}$")] = None

@root_validator(pre=True)
def either_key_or_email_password(cls, values):
assert 'email' in values or 'api_key' in values, 'At least one of email or api_key should be included'
assert 'email' not in values or 'password' in values, 'Provide both email and password for login auth'
return values


class Client(object):

def __init__(self, url, api_key, session=None, extra_headers: dict = None):
def __init__(self, url, api_key, credentials=None, session=None, extra_headers: dict = None):
""" Initialize the client. Do this before using other Label Studio SDK classes and methods in your script.

Parameters
Expand All @@ -24,17 +39,36 @@ def __init__(self, url, api_key, session=None, extra_headers: dict = None):
Example: http://localhost:8080
api_key: str
User token for the API. You can find this on your user account page in Label Studio.
credentials: ClientCredentials
User email and password or api_key.
session: requests.Session()
If None, a new one is created.
extra_headers: dict
Additional headers that will be passed to each http request
"""
self.url = url.rstrip('/')
self.api_key = api_key
self.session = session or self.get_session()

# set headers
self.headers = {'Authorization': f'Token {self.api_key}'}
if extra_headers:
self.headers.update(extra_headers)

# set api key or get it using credentials (username and password)
if api_key is not None:
warnings.warn("A deprecation warning to fit accordingly to your deprecation policy", DeprecationWarning)
credentials = ClientCredentials(api_key=api_key)
self.api_key = credentials.api_key if credentials.api_key else self.get_api_key(credentials)

def get_api_key(self, credentials: ClientCredentials):
login_url = self.get_url("/user/login")
# Retrieve and set the CSRF token first
self.session.get(login_url)
csrf_token = self.session.cookies.get('csrftoken', None)
login_data = dict(**credentials.dict(), csrfmiddlewaretoken=csrf_token)
self.session.post(login_url, data=login_data, headers=dict(Referer=self.url)).raise_for_status()
api_key = self.session.get(self.get_url("/api/current-user/token")).json().get("token")
return api_key

def check_connection(self):
""" Call Label Studio /health endpoint to check the connection to the server.
Expand Down