215 lines
8.6 KiB
Python
215 lines
8.6 KiB
Python
import time
|
|
import re
|
|
import sys
|
|
|
|
import requests
|
|
from requests_oauthlib import OAuth2Session
|
|
from oauthlib.oauth2 import LegacyApplicationClient
|
|
from oauthlib.oauth2.rfc6749.errors import InvalidGrantError
|
|
|
|
from mirri.entities.strain import ValidationError
|
|
|
|
|
|
class BiolomicsClient:
    """HTTP client for the data/search endpoints of a Biolomics server.

    The client authenticates with username/password credentials through
    oauthlib's ``LegacyApplicationClient``, downloads the server schemas on
    construction and uses them to validate end point names and record
    payloads before sending create/update requests.
    """

    # Class-level defaults; get_schemas() stores the fetched values on the
    # instance.
    schemas = None
    allowed_fields = None

    def __init__(self, server_url, api_version, client_id, client_secret,
                 username, password, website_id=1, verbose=False):
        """Store the connection settings and fetch the server schemas.

        Building an instance performs network requests: one to obtain an
        access token and one to download ``<server_url>/schemas``.

        Raises:
            ValueError: if the schemas request does not answer with 200.
        """
        self._client_id = client_id
        self._client_secret = client_secret
        self._username = username
        self._password = password
        self._client = None
        self.server_url = server_url
        self._api_version = api_version
        self._auth_url = self.server_url + "/connect/token"
        self.access_token = None
        self.website_id = website_id
        self._verbose = verbose
        self._schema = self.get_schemas()

    def get_access_token(self):
        """Return a valid access token, requesting a new one when needed.

        A new token is fetched when no OAuth client exists yet, or when the
        stored token is expired or carries no expiry information.

        Raises:
            InvalidGrantError: propagated from the token request when the
                credentials are rejected.
        """
        if self._client is None:
            self._client = LegacyApplicationClient(client_id=self._client_id)
            authenticated = False
        else:
            # The client keeps an empty token after a failed fetch, and the
            # server may omit the expiry; treat both cases as "expired"
            # instead of failing with KeyError.
            expires_at = self._client.token.get("expires_at", 0)
            authenticated = expires_at > time.time()

        if not authenticated:
            oauth = OAuth2Session(client=self._client)
            try:
                token = oauth.fetch_token(
                    token_url=self._auth_url,
                    username=self._username,
                    password=self._password,
                    client_id=self._client_id,
                    client_secret=self._client_secret,
                )
            finally:
                # Release the session whatever the outcome of the request.
                oauth.close()
            self.access_token = token["access_token"]
        return self.access_token

    def _build_headers(self):
        """Return the request headers, refreshing the token if required."""
        self.get_access_token()
        return {
            "accept": "application/json",
            "websiteId": str(self.website_id),
            "Authorization": f"Bearer {self.access_token}",
        }

    def get_detail_url(self, end_point, record_id, api_version=None):
        """Return the URL of a single record of ``end_point``.

        The version segment is only added when ``api_version`` is given;
        the fallback to the instance's own API version is disabled.
        """
        if api_version:
            return "/".join([self.server_url, api_version, 'data',
                             end_point, str(record_id)])
        return "/".join([self.server_url, 'data', end_point, str(record_id)])

    def get_list_url(self, end_point):
        """Return the collection URL of ``end_point``.

        The URL carries no API version segment (the versioned variant is
        disabled).
        """
        return "/".join([self.server_url, 'data', end_point])

    def get_search_url(self, end_point):
        """Return the versioned search URL of ``end_point``."""
        return "/".join([self.server_url, self._api_version, 'search',
                         end_point])

    def get_find_by_name_url(self, end_point):
        """Return the ``findByName`` URL below the search URL."""
        return "/".join([self.get_search_url(end_point), 'findByName'])

    def search(self, end_point, search_query):
        """POST ``search_query`` as JSON to the search URL of ``end_point``.

        Returns the raw ``requests`` response. When the client is verbose
        the elapsed request time is written to stdout.

        Raises:
            ValueError: if ``end_point`` is not a known end point.
        """
        self._check_end_point_exists(end_point)
        header = self._build_headers()
        url = self.get_search_url(end_point)
        time0 = time.time()
        response = requests.post(url, json=search_query, headers=header)
        time1 = time.time()
        if self._verbose:
            sys.stdout.write(
                f'Search to {end_point} request time for {url}: {time1 - time0}\n')
        return response

    def retrieve(self, end_point, record_id):
        """GET one record of ``end_point`` through the versioned detail URL.

        Returns the raw ``requests`` response. When the client is verbose
        the elapsed request time is written to stdout.

        Raises:
            ValueError: if ``end_point`` is not a known end point.
        """
        self._check_end_point_exists(end_point)
        header = self._build_headers()
        url = self.get_detail_url(end_point, record_id,
                                  api_version=self._api_version)
        time0 = time.time()
        response = requests.get(url, headers=header)
        time1 = time.time()
        if self._verbose:
            sys.stdout.write(
                f'Get to {end_point} request time for {url}: {time1 - time0}\n')
        return response

    def create(self, end_point, data):
        """Validate ``data`` and POST it to the list URL of ``end_point``.

        Returns the raw ``requests`` response.

        Raises:
            ValueError: if ``end_point`` is not a known end point.
            ValidationError: if ``data`` does not match the schema.
        """
        self._check_end_point_exists(end_point)
        self._check_data_consistency(data, self.allowed_fields[end_point])
        header = self._build_headers()
        url = self.get_list_url(end_point)
        return requests.post(url, json=data, headers=header)

    def update(self, end_point, record_id, data):
        """Validate ``data`` and PUT it to the record's detail URL.

        Returns the raw ``requests`` response.

        Raises:
            ValueError: if ``end_point`` is not a known end point.
            ValidationError: if ``data`` does not match the schema.
        """
        self._check_end_point_exists(end_point)
        self._check_data_consistency(data, self.allowed_fields[end_point],
                                     update=True)
        header = self._build_headers()
        url = self.get_detail_url(end_point, record_id=record_id)
        return requests.put(url, json=data, headers=header)

    def delete(self, end_point, record_id):
        """Send a DELETE for the record and return the raw response.

        Raises:
            ValueError: if ``end_point`` is not a known end point.
        """
        self._check_end_point_exists(end_point)
        header = self._build_headers()
        url = self.get_detail_url(end_point, record_id)
        return requests.delete(url, headers=header)

    def find_by_name(self, end_point, name):
        """GET the ``findByName`` URL with ``name`` as query parameter.

        Raises:
            ValueError: if ``end_point`` is not a known end point.
        """
        self._check_end_point_exists(end_point)
        header = self._build_headers()
        url = self.get_find_by_name_url(end_point)
        return requests.get(url, headers=header, params={'name': name})

    def get_schemas(self):
        """Return the server schemas, downloading them on first use.

        Also fills ``allowed_fields`` from the schemas when it is unset.

        Raises:
            ValueError: if the schemas request does not answer with 200.
        """
        if self.schemas is None:
            headers = self._build_headers()
            url = self.server_url + '/schemas'
            response = requests.get(url, headers=headers)
            if response.status_code != 200:
                raise ValueError(f"{response.status_code}: {response.text}")
            self.schemas = response.json()
        if self.allowed_fields is None:
            self.allowed_fields = self._process_schema(self.schemas)
        return self.schemas

    @staticmethod
    def _process_schema(schemas):
        """Map every table view name to its fields, keyed by field title.

        Only the first schema of ``schemas`` is read: each entry of its
        ``TableViews`` contributes ``TableViewName`` as key and a
        ``{title: field}`` dict built from ``ResultFields`` as value.
        """
        schema = schemas[0]
        return {
            table_view['TableViewName']: {
                field['title']: field for field in table_view['ResultFields']
            }
            for table_view in schema['TableViews']
        }

    def _check_end_point_exists(self, endpoint):
        """Raise ``ValueError`` when ``endpoint`` is not in the schemas."""
        if endpoint not in self.allowed_fields:
            raise ValueError(f'{endpoint} not a recognised endpoint')

    def _check_data_consistency(self, data, allowed_fields, update=False):
        """Validate the keys of ``data`` and every field of RecordDetails.

        Args:
            data: record payload to send to the server.
            allowed_fields: ``{field title: field schema}`` of the end point.
            update: when true ``data`` must hold RecordDetails, RecordName
                and RecordId; otherwise only RecordDetails, RecordName and
                Acronym are accepted as keys.

        Raises:
            ValidationError: on bad keys, unknown fields or bad field values.
        """
        update_mandatory = {'RecordDetails', 'RecordName', 'RecordId'}
        if update and not update_mandatory.issubset(data.keys()):
            msg = 'Updating data keys must be RecordDetails, RecordName and RecordId'
            raise ValidationError(msg)

        if not update and set(data.keys()).difference(
                ['RecordDetails', 'RecordName', 'Acronym']):
            msg = 'data keys must be RecordDetails and RecordName or Acronym'
            raise ValidationError(msg)

        # A create payload can pass the key check above without carrying
        # RecordDetails; report that as a validation problem, not KeyError.
        if 'RecordDetails' not in data:
            raise ValidationError('data must contain the RecordDetails key')

        for field_name, field_value in data['RecordDetails'].items():
            if field_name not in allowed_fields:
                raise ValidationError(f'{field_name} not in allowed fields')
            self._check_field_schema(field_name, allowed_fields[field_name],
                                     field_value)

    @staticmethod
    def _check_field_schema(field_name, field_schema, field_value):
        """Check one field value against its schema.

        The ``FieldType`` of value and schema must match. When the schema
        lists ``states`` the value must be one of them (compared after
        removing a parenthesised suffix from each state). When the schema
        also lists ``subfields`` the value is a list of ``Name``/``Value``
        items: each name must be a used subfield and each value a state.

        Raises:
            ValidationError: when any of the checks fails.
        """
        if field_schema['FieldType'] != field_value['FieldType']:
            msg = f"Bad FieldType ({field_value['FieldType']}) for {field_name}. "
            msg += f"It should be {field_schema['FieldType']}"
            raise ValidationError(msg)

        states = field_schema.get('states')
        if states:
            # Drop the " (...)" suffix of each state before comparing.
            states = [re.sub(r" *\(.*\)", "", s) for s in states]

        subfields = field_schema.get('subfields')
        if subfields is not None and states is not None:
            subfield_names = [subfield['SubFieldName']
                              for subfield in subfields if subfield['IsUsed']]

            for val in field_value['Value']:
                if val['Name'] not in subfield_names:
                    msg = f"{field_name}: {val['Name']} not in {subfield_names}"
                    raise ValidationError(msg)

                if val['Value'] not in states:
                    # Report the offending item, not the whole value list.
                    msg = f"{val['Value']} not a valid value for "
                    msg += f"{field_name}, Allowed values: {'. '.join(states)}"
                    raise ValidationError(msg)

        elif states is not None:
            if field_value['Value'] not in states:
                msg = f"{field_value['Value']} not a valid value for "
                msg += f"{field_name}, Allowed values: {'. '.join(states)}"
                raise ValidationError(msg)

    def rollback(self, created_ids):
        """Delete every ``(endpoint, record id)`` pair of ``created_ids``.

        Best effort: a failing delete is reported on stderr and the
        remaining records are still processed.
        """
        for endpoint, id_ in created_ids:
            try:
                self.delete(end_point=endpoint, record_id=id_)
            except Exception as error:  # deliberate: keep rolling back
                sys.stderr.write(
                    f'Rollback of {endpoint} {id_} failed: {error}\n')
|