Source code for OpenAttack.data_manager
import os
from typing import Any, Optional
from .exceptions import UnknownDataException, DataNotExistException
from .data import data_list
import inspect
[docs]class DataManager(object):
"""
DataManager is a module that manages all the resources used in Attacker, Metric, Substitute, TextProcessors and utils.
It reads configuration files in OpenAttack/data/\*.py, and initialize these resources when you load them.
You can use
.. code-block:: python
for data_name in OpenAttack.DataManager.AVAILABLE_DATAS:
OpenAttack.download(data_name)
to download all the available resources, but this is not recommend because of the huge network cost.
``OpenAttack.load`` and ``OpenAttack.download`` is a alias of
``OpenAttack.DataManager.load`` and ``OpenAttack.DataManager.download``, they are exactly equivalent.
These two methods are useful for both developer and user, that's the reason we provide shortter name for them.
"""
AVAILABLE_DATAS = [x["name"] for x in data_list]
data_path = {
x["name"]: os.path.join(os.getcwd(), "data", x["name"]) for x in data_list
}
data_download = {x["name"]: x["download"] for x in data_list}
data_loader = {x["name"]: x["load"] for x in data_list}
data_reference = {kw: None for kw in AVAILABLE_DATAS}
__auto_download = True
source = "https://data.thunlp.org/"
def __init__(self):
raise NotImplementedError()
[docs] @classmethod
def enable_cdn(cls):
"""
Enable cdn for all official downloads.
"""
cls.source = "https://cdn.data.thunlp.org/"
[docs] @classmethod
def disable_cdn(cls):
"""
Disable cdn for all official downloads.
"""
cls.source = "https://data.thunlp.org/"
[docs] @classmethod
def load(cls, data_name : str, cached : bool = True) -> Any:
"""
Load data from local storage, and download it automatically if not exists.
Args:
data_name: The name of resource that you want to load. You can find all the available resource names in ``DataManager.AVAILABLE_DATAS``. *Note: all the names are* **CASE-SENSITIVE**.
cached: If **cached** is *True*, DataManager will lookup the cache before load it to avoid duplicate disk IO. If **cached** is *False*, DataManager will directly load data from disk. **Default:** *True*.
Returns:
The object returned by LOAD function of corresponding data.
Raises:
UnknownDataException: For loading an unavailable data.
DataNotExistException: For loading a data that has not been downloaded. This appends when AutoDownload mechanism is disabled.
"""
if data_name not in cls.AVAILABLE_DATAS:
raise UnknownDataException()
if not os.path.exists(cls.data_path[data_name]):
if cls.__auto_download:
cls.download(data_name)
else:
raise DataNotExistException(data_name, cls.data_path[data_name])
if not cached:
return cls.data_loader[data_name](cls.data_path[data_name])
elif cls.data_reference[data_name] is None:
try:
cls.data_reference[data_name] = cls.data_loader[data_name](
cls.data_path[data_name]
)
except OSError:
raise DataNotExistException(data_name, cls.data_path[data_name])
return cls.data_reference[data_name]
[docs] @classmethod
def loadVictim(cls, data_name, cached=True):
"""
This method is equivalent to ``DataManager.load("Victim." + data_name)``.
"""
return cls.load("Victim." + data_name, cached=cached)
[docs] @classmethod
def loadTProcess(cls, data_name, cached=True):
"""
This method is equivalent to ``DataManager.load("TProcess." + data_name)``.
"""
return cls.load("TProcess." + data_name, cached=cached)
[docs] @classmethod
def loadAttackAssist(cls, data_name, cached=True):
"""
This method is equivalent to ``DataManager.load("AttackAssist." + data_name)``.
"""
return cls.load("AttackAssist." + data_name, cached=cached)
[docs] @classmethod
def setAutoDownload(cls, enabled : bool = True):
"""
AutoDownload mechanism is enabled by default.
Args:
enabled: Change if DataManager automatically download the data when loading.
"""
cls.__auto_download = enabled
[docs] @classmethod
def get(cls, data_name : str) -> str:
"""
Args:
data_name: The name of data.
Returns:
Relative path of data.
"""
if data_name not in cls.AVAILABLE_DATAS:
raise UnknownDataException
return cls.data_path[data_name]
[docs] @classmethod
def set_path(cls, path : str, data_name : Optional[str] = None):
"""Set the path for a specific data or for all data.
If **data_name** is *None*, all paths will be changed to corresponding file under **path** directory.
If **data_name** is *not None*, the specific data path will be changed to **path**.
The default paths for all data are ``./data/<data_name>``, and you can manually change them using this method .
Args:
path: The path to data, or path to the directory where all data is stored.
data_name: The name of data. If **data_name** is *None*, all paths will be changed.
"""
if data_name is None:
nw_dict = {}
for kw, pt in cls.data_path.items():
nw_dict[kw] = os.path.join(path, os.path.basename(pt))
cls.data_path = nw_dict
else:
if data_name not in cls.AVAILABLE_DATAS:
raise UnknownDataException
cls.data_path[data_name] = path
[docs] @classmethod
def download(cls, data_name : str, path : Optional[str] = None, force : bool = False):
"""
This method will check if data exists before getting it from "Data Server".You can use **force** to skip this step.
Args:
data_name: Name of the data that you want to download.
path: Specify a path when before download. Leaves None for download to default **path**.
force: Force download the data.
Raises:
UnknownDataException: For downloading an unavailable data.
"""
if data_name not in cls.AVAILABLE_DATAS:
raise UnknownDataException()
if path is None:
path = cls.data_path[data_name]
if os.path.exists(path) and not force:
return True
download_func = cls.data_download[data_name]
parent_dir = os.path.dirname(path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
num_args = len(inspect.getfullargspec(download_func).args)
if num_args == 1:
download_func(path)
elif num_args == 2:
download_func(path, cls.source)
return True