mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-03-12 04:36:10 -07:00
Pep8 formatting
This commit is contained in:
parent
01f1cbb629
commit
4f936c3049
python_backend
@ -39,4 +39,4 @@ config = template.substitute(
|
||||
use_auto_device_map=args.use_auto_device_map,
|
||||
)
|
||||
with open(model_dir_path/'../config.pbtxt', 'w') as f:
|
||||
f.write(config)
|
||||
f.write(config)
|
||||
|
@ -1,28 +1,33 @@
|
||||
import json
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
import torch
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from torch.utils.dlpack import to_dlpack, from_dlpack
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def pb2torch(request, name):
|
||||
tensor = pb_utils.get_input_tensor_by_name(request, name)
|
||||
return from_dlpack(tensor.to_dlpack())
|
||||
|
||||
|
||||
def torch2pb(name, tensor):
|
||||
return pb_utils.Tensor.from_dlpack(name, to_dlpack(tensor))
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
def initialize(self, args):
|
||||
self.model_config = model_config = json.loads(args["model_config"])
|
||||
org_name = model_config["parameters"].get("org_name", {"string_value": "Salesforce"})["string_value"]
|
||||
model_name = org_name + "/" + model_config["parameters"]["model_name"]["string_value"]
|
||||
|
||||
get_bool = lambda x: model_config["parameters"][x]["string_value"].lower() in ["1", "true"]
|
||||
def get_bool(x):
|
||||
return model_config["parameters"][x]["string_value"].lower() in ["1", "true"]
|
||||
|
||||
is_half = get_bool("use_half")
|
||||
int8 = get_bool("use_int8") # this will make inference marginally slower, but will allow bigger models to fit in GPU
|
||||
# This will make inference marginally slower, but will allow bigger models to fit in GPU
|
||||
int8 = get_bool("use_int8")
|
||||
auto_device_map = get_bool("use_auto_device_map")
|
||||
|
||||
print(f"is_half: {is_half}, int8: {int8}, auto_device_map: {auto_device_map}")
|
||||
@ -37,8 +42,8 @@ class TritonPythonModel:
|
||||
print(f"Model {model_name} Loaded. Footprint: {self.model.get_memory_footprint()}")
|
||||
|
||||
# set max_batch_size
|
||||
self.max_batch_size = 0 # model_config["max_batch_size"]
|
||||
|
||||
self.max_batch_size = 0 # model_config["max_batch_size"]
|
||||
|
||||
def execute(self, requests):
|
||||
# TODO: don't just loop over requests. batch them up
|
||||
|
||||
@ -55,7 +60,7 @@ class TritonPythonModel:
|
||||
attention_mask = torch.zeros(input_ids_torch.shape, dtype=torch.long)
|
||||
for i, l in enumerate(input_lengths_torch):
|
||||
attention_mask[i, :l] = 1
|
||||
|
||||
|
||||
# Output length
|
||||
max_new_tokens = request_output_len_torch[0][0]
|
||||
|
||||
@ -71,9 +76,9 @@ class TritonPythonModel:
|
||||
max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p, num_return_sequences=n_samples,
|
||||
temperature=temperature,
|
||||
)
|
||||
# assert len(output_ids.shape) == 2, "huggingface format is batch x seq_len"
|
||||
# assert output_ids.shape[0] == input_ids_torch.shape[0], "expecting batch size to match input"
|
||||
output_ids = output_ids.unsqueeze(1) # client wants batch x beam_width x seq_len and we don't support beam_width yet
|
||||
|
||||
# client wants batch x beam_width x seq_len and we don't support beam_width yet
|
||||
output_ids = output_ids.unsqueeze(1)
|
||||
|
||||
# create output tensors
|
||||
out_tensor_pb = torch2pb("output_ids", output_ids)
|
||||
@ -88,4 +93,4 @@ class TritonPythonModel:
|
||||
response = pb_utils.InferenceResponse([out_tensor_pb, sequence_length_pb])
|
||||
responses.append(response)
|
||||
|
||||
return responses
|
||||
return responses
|
||||
|
Loading…
x
Reference in New Issue
Block a user