PinkAlpaca committed on
Commit 6d3fc7f · verified · 1 Parent(s): 72bedc9

Upload 2 files

Files changed (2)
  1. app.py +249 -0
  2. requirements.txt +6 -7
app.py ADDED
@@ -0,0 +1,249 @@
+ import os
+ import logging
+ from typing import List, Optional, Tuple, Dict, Any
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from transformers import AutoModelForCausalLM, GemmaTokenizerFast
+ import torch
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.DEBUG,  # Set to DEBUG for detailed logs
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+     handlers=[
+         logging.FileHandler("app.log"),  # Log to file
+         logging.StreamHandler()  # Log to console
+     ]
+ )
+
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Environment configuration
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ model_id = "google/gemma-2-2b-it"
+ tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16
+ )
+ model.config.sliding_window = 4096
+ model.eval()
+
+ # Constants
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+
+ # Data model for API request
+ class Item(BaseModel):
+     input: Optional[str] = None
+     system_prompt: Optional[str] = None
+     system_output: Optional[str] = None
+     history: Optional[List[Tuple[str, str]]] = None
+     templates: Optional[List[str]] = None
+     temperature: float = 0.6
+     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
+     top_p: float = 0.9
+     repetition_penalty: float = 1.2
+
+ # Function to generate the response
+ def generate_response(item: Item) -> Dict[str, Any]:
+     logger.debug(f"Received request: {item}")
+
+     conversation = []
+     if item.history:
+         for user, assistant in item.history:
+             conversation.extend([
+                 {"role": "user", "content": user},
+                 {"role": "assistant", "content": assistant},
+             ])
+     conversation.append({"role": "user", "content": item.input})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+
+     input_ids = input_ids.to(device)
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         max_new_tokens=item.max_new_tokens,
+         do_sample=True,
+         top_p=item.top_p,
+         temperature=item.temperature,
+         num_beams=1,
+         repetition_penalty=item.repetition_penalty,
+     )
+
+     try:
+         logger.debug("Starting text generation")
+         output = model.generate(**generate_kwargs, return_dict_in_generate=True)
+         decoded_output = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
+         logger.debug("Text generation successful")
+         return {"output": decoded_output}
+     except Exception as e:
+         logger.error(f"Error during text generation: {str(e)}", exc_info=True)
+         raise HTTPException(status_code=500, detail=str(e))
+
+ # Endpoint for generating text
+ @app.post("/")
+ async def generate_text(item: Item):
+     logger.info("Processing request")
+     if item.input is None and item.system_prompt is None:
+         logger.warning("Missing required parameters")
+         raise HTTPException(status_code=400, detail="Parameter `input` or `system_prompt` is required.")
+
+     response = generate_response(item)
+     logger.info("Request processed successfully")
+     return response
+
+ # Run the app
+ if __name__ == "__main__":
+     import uvicorn
+     logger.info("Starting server")
+     uvicorn.run(app, host="0.0.0.0", port=8000)
+
+ '''import os
+ from threading import Thread
+ from typing import Iterator
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
+
+ DESCRIPTION = """\
+ # Gemma 2 2B IT
+
+ Gemma 2 is Google's latest iteration of open LLMs.
+ This is a demo of [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it), fine-tuned for instruction following.
+ For more details, please check [our post](https://huggingface.co/blog/gemma2).
+
+ 👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it) and the 9B version in [this Space](https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it).
+ """
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model_id = "google/gemma-2-2b-it"
+ tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.config.sliding_window = 4096
+ model.eval()
+
+
+ @spaces.GPU(duration=90)
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = []
+     for user, assistant in chat_history:
+         conversation.extend(
+             [
+                 {"role": "user", "content": user},
+                 {"role": "assistant", "content": assistant},
+             ]
+         )
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+     cache_examples=False,
+ )
+
+ with gr.Blocks(css="style.css", fill_height=True) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
+ '''
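
Once the server in app.py is running, the root endpoint can be exercised with a plain HTTP POST. Below is a minimal client sketch, assuming the `requests` package is available locally (it is no longer pinned in requirements.txt); the payload fields mirror the `Item` model and the host/port come from the `uvicorn.run` call above.

import requests

# Request body matching the Item model defined in app.py
payload = {
    "input": "Explain the plot of Cinderella in a sentence.",
    "temperature": 0.6,
    "max_new_tokens": 256,
    "top_p": 0.9,
}

# POST to the root endpoint served by uvicorn on port 8000
response = requests.post("http://localhost:8000/", json=payload, timeout=120)
print(response.json()["output"])
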
requirements.txt CHANGED
@@ -1,7 +1,6 @@
- fastapi
- uvicorn
- huggingface_hub
- pydantic
- google-cloud-aiplatform
- requests
- transformers
+ accelerate==0.33.0
+ bitsandbytes==0.43.2
+ gradio==4.39.0
+ spaces==0.29.2
+ torch==2.2.0
+ transformers==4.43.3