Friday, March 14, 2008

Lua 的 Dispatcher

-----------------------------------------------------------------------------
-- A hacked dispatcher module
-- LuaSocket sample files
-- Author: Diego Nehab
-- RCS ID: $$
-----------------------------------------------------------------------------
local base = _G
local table = require("table")
local socket = require("socket")
local coroutine = require("coroutine")
module("dispatch")

-- if too much time goes by without any activity in one of our sockets, we
-- just kill it
-- (the limit is compared against differences of socket.gettime() values in
-- the dispatcher's step method; presumably seconds -- TODO confirm)
TIMEOUT = 60

-----------------------------------------------------------------------------
-- The design allows for 3 types of dispatchers:
-- sequential
-- coroutine
-- threaded
-- (only the sequential and coroutine handlers are defined in this file)
-- The user can choose whichever one is needed
-----------------------------------------------------------------------------
-- registry of dispatcher constructors, keyed by mode name
local handlert = {}

-- Creates a dispatcher of the requested kind.
-- mode: name of a constructor registered in handlert; when omitted the
--       default handler is "coroutine"
-- returns whatever the selected constructor returns
-- raises an error naming the mode when no constructor is registered for it
function newhandler(mode)
    mode = mode or "coroutine"
    local constructor = handlert[mode]
    if not constructor then
        -- fail with a readable message instead of "attempt to call a nil value"
        base.error("unknown dispatcher mode '" .. base.tostring(mode) .. "'", 2)
    end
    return constructor()
end

-- start method of the sequential handler: runs the task immediately in the
-- caller's context and hands back everything it returns
-- _handler: the handler the method was invoked on (not used)
-- task: function to run
local function seqstart(_handler, task)
    return task()
end

-- sequential handler simply calls the functions and doesn't wrap I/O
-- (note: for a table t, t.a is the same as t["a"], so this function is
-- what newhandler finds under the "sequential" key)
function handlert.sequential()
    local handler = {}
    -- plain socket constructor, no wrapping
    handler.tcp = socket.tcp
    -- start just runs the function right away
    handler.start = seqstart
    return handler
end

-----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home.
-----------------------------------------------------------------------------
-- we can't yield across calls to protect, so we rewrite it with coxpcall
-- make sure you don't require any module that uses socket.protect before
-- loading our hack
-- Replacement for socket.protect that can be yielded across.
-- f: function to protect
-- returns a function that runs f inside a coroutine of its own, forwarding
-- the caller's arguments to it. Whenever f yields, the yielded values are
-- yielded again to whoever is running the wrapper, and the values that yield
-- returns are used to resume f. When f finishes, its results are returned.
-- When f raises a table, the wrapper returns nil followed by the table's
-- first element; any other error value is raised again.
function socket.protect(f)
    return function(...)
        local co = coroutine.create(f)
        -- values handed to the next resume; starts with the caller's arguments
        local args = {...}
        while true do
            local results = {coroutine.resume(co, base.unpack(args))}
            local status = table.remove(results, 1)
            if not status then
                -- globals must be reached through base: after module() the
                -- bare name 'type' is not visible in this environment
                if base.type(results[1]) == 'table' then
                    return nil, results[1][1]
                else base.error(results[1]) end
            end
            if coroutine.status(co) == "suspended" then
                -- pass the yield up and remember what we are resumed with
                args = {coroutine.yield(base.unpack(results))}
            else
                return base.unpack(results)
            end
        end
    end
end

-----------------------------------------------------------------------------
-- Simple set data structure. O(1) everything.
-----------------------------------------------------------------------------
-- Creates an empty set.
-- The returned table holds the members in its array part (positions 1..n,
-- in no particular order after removals) and offers two methods:
--   set:insert(value)  adds value unless it is already a member
--   set:remove(value)  removes value if it is a member
local function newset()
    -- maps each member to its current position in the array part
    local reverse = {}
    local set = {}
    return base.setmetatable(set, {__index = {
        insert = function(set, value)
            if not reverse[value] then
                -- append and remember where the value went
                local index = #set + 1
                set[index] = value
                reverse[value] = index
            end
        end,
        remove = function(set, value)
            local index = reverse[value]
            if index then
                reverse[value] = nil
                -- pop the last member and, unless it is the one being
                -- removed, move it into the vacated position
                local top = table.remove(set)
                if top ~= value then
                    reverse[top] = index
                    set[index] = top
                end
            end
        end
    }})
end

-----------------------------------------------------------------------------
-- socket.tcp() wrapper for the coroutine dispatcher
-----------------------------------------------------------------------------
-- Wraps a socket so that its blocking operations cooperate with a coroutine
-- dispatcher.
-- dispatcher: table with 'sending', 'receiving' and 'stamp' fields, as used
--             by the methods below
-- tcp: the socket object to wrap; when it is nil/false the function returns
--      nil followed by 'error', so the result of a failed constructor can be
--      passed straight through
-- error: error message accompanying a missing tcp
-- returns a wrap object that behaves like the socket, except that send,
-- receive, connect and accept yield to the dispatcher instead of blocking,
-- settimeout only records the request, and close also unregisters the socket
-- from the dispatcher
local function cowrap(dispatcher, tcp, error)
    if not tcp then return nil, error end
    -- put it in non-blocking mode right away
    -- (a timeout of 0 is non-blocking mode)
    tcp:settimeout(0)
    -- metatable for wrap produces new methods on demand for those that we
    -- don't override explicitly.
    local metat = { __index = function(wrapper, key)
        -- the generated method ignores the receiver it is called on and
        -- invokes the real socket's method on the real socket instead
        wrapper[key] = function(_, ...)
            return tcp[key](tcp, ...)
        end
        return wrapper[key]
    end}
    -- does our user want to do his own non-blocking I/O?
    -- (note that this is a local variable, shared by the closures below)
    local zero = false
    -- create a wrap object that will behave just like a real socket object
    local wrap = { }
    -- we ignore settimeout to preserve our 0 timeout, but record whether
    -- the user wants to do his own non-blocking I/O
    -- (the ':' form defines an implicit self parameter)
    function wrap:settimeout(value, mode)
        if value == 0 then zero = true
        else zero = false end
        return 1
    end
    -- send in non-blocking mode and yield on timeout
    function wrap:send(data, first, last)
        first = (first or 1) - 1
        local result, error
        while true do
            -- return control to dispatcher and tell it we want to send
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try sending; all three assigned variables are local to this
            -- method, and 'first' is updated from the third result so the
            -- next attempt starts at first+1
            result, error, first = tcp:send(data, first+1, last)
            -- if we are done, or there was an unexpected error,
            -- break away from loop
            if error ~= "timeout" then return result, error, first end
        end
    end
    -- receive in non-blocking mode and yield on timeout
    -- or simply return partial read, if user requested timeout = 0
    function wrap:receive(pattern, partial)
        local error = "timeout"
        local value
        while true do
            -- return control to dispatcher and tell it we want to receive
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try receiving
            value, error, partial = tcp:receive(pattern, partial)
            -- if we are done, or there was an unexpected error,
            -- break away from loop. also, if the user requested
            -- zero timeout, return all we got
            if (error ~= "timeout") or zero then
                return value, error, partial
            end
        end
    end
    -- connect in non-blocking mode and yield on timeout
    function wrap:connect(host, port)
        local result, error = tcp:connect(host, port)
        if error == "timeout" then
            -- return control to dispatcher. we will be writable when
            -- connection succeeds.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- when we come back, check if connection was successful
            result, error = tcp:connect(host, port)
            if result or error == "already connected" then return 1
            else return nil, "non-blocking connect failed" end
        else return result, error end
    end
    -- accept in non-blocking mode and yield on timeout
    function wrap:accept()
        while 1 do
            -- return control to dispatcher. we will be readable when a
            -- connection arrives.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            local client, error = tcp:accept()
            if error ~= "timeout" then
                -- wrap the accepted socket too (or pass the error through)
                return cowrap(dispatcher, client, error)
            end
        end
    end
    -- remove cortn from context
    function wrap:close()
        dispatcher.stamp[tcp] = nil
        dispatcher.sending.set:remove(tcp)
        dispatcher.sending.cortn[tcp] = nil
        dispatcher.receiving.set:remove(tcp)
        dispatcher.receiving.cortn[tcp] = nil
        return tcp:close()
    end
    return base.setmetatable(wrap, metat)
end


-----------------------------------------------------------------------------
-- Our coroutine dispatcher
-----------------------------------------------------------------------------
-- metatable given to the dispatcher built by handlert.coroutine; the
-- dispatcher methods (step, start) are stored in its __index table
local cometat = { __index = {} }

-- Records what a coroutine asked for when it last gave control back.
-- The arguments are laid out so that the results of wakeup() or of
-- coroutine.resume() can be passed in directly:
-- cortn: the coroutine that was resumed (may be nil)
-- status: first result of the resume; false means the coroutine failed
-- operation: on success, the operation table the coroutine yielded (nil if
--            it did not ask for anything); on failure, the error value
-- tcp: the socket the coroutine wants to operate on
function schedule(cortn, status, operation, tcp)
    if not status then
        -- the coroutine failed: 'operation' holds its error, raise it again
        base.error(operation)
    end
    if cortn and operation then
        -- park the coroutine until tcp is ready and note the time
        operation.set:insert(tcp)
        operation.cortn[tcp] = cortn
        operation.stamp[tcp] = socket.gettime()
    end
end

-- Forgets that a coroutine is waiting on tcp for the given operation.
-- operation: operation table with 'cortn' and 'set' fields
-- tcp: the socket to unregister
function kick(operation, tcp)
    local waiting, watched = operation.cortn, operation.set
    waiting[tcp] = nil
    watched:remove(tcp)
end

-- Resumes the coroutine waiting on tcp for the given operation, if any.
-- returns the coroutine followed by the results of coroutine.resume, or
-- nil, true when nobody is waiting
function wakeup(operation, tcp)
    local cortn = operation.cortn[tcp]
    if not cortn then
        -- otherwise, just get scheduler not to do anything
        return nil, true
    end
    -- cortn is still valid: unregister it, then wake it up
    kick(operation, tcp)
    return cortn, coroutine.resume(cortn)
end

-- Tells the coroutine waiting on tcp for the given operation, if any, that
-- it timed out: it is unregistered and resumed with the string "timeout".
-- The results of that resume are not looked at.
function abort(operation, tcp)
    local cortn = operation.cortn[tcp]
    if not cortn then return end
    kick(operation, tcp)
    coroutine.resume(cortn, "timeout")
end

-- step through all active cortns: one round of the dispatcher
function cometat.__index:step()
    local receiving, sending = self.receiving, self.sending
    -- check which sockets are interesting and act on them
    -- (the last argument to select is its timeout)
    local readable, writable = socket.select(receiving.set, sending.set, 1)
    -- for all readable connections, resume their cortns and reschedule
    -- when they yield back to us
    for _, tcp in base.ipairs(readable) do
        schedule(wakeup(receiving, tcp))
    end
    -- for all writable connections, do the same
    for _, tcp in base.ipairs(writable) do
        schedule(wakeup(sending, tcp))
    end
    -- politely ask replacement I/O functions in idle cortns to
    -- return reporting a timeout
    local now = socket.gettime()
    for tcp, stamp in base.pairs(self.stamp) do
        local idle = now - stamp
        if tcp.class == "tcp{client}" and idle > TIMEOUT then
            abort(sending, tcp)
            abort(receiving, tcp)
        end
    end
end

-- Runs func in a new coroutine until it first yields or finishes, then
-- registers whatever it asked the dispatcher for.
function cometat.__index:start(func)
    local co = coroutine.create(func)
    schedule(co, coroutine.resume(co))
end

-- coroutine-based dispatcher
-- returns a dispatcher table with:
--   stamp      table shared with both operation tables below
--   sending    operation table for coroutines waiting to send
--   receiving  operation table for coroutines waiting to receive
--   tcp        function creating a wrapped socket bound to this dispatcher
-- and with cometat as its metatable
function handlert.coroutine()
    local stamp = {}
    -- builds one operation table; both of them share the same stamp table
    local function newoperation(name)
        return {
            name = name,
            set = newset(),
            cortn = {},
            stamp = stamp
        }
    end
    local dispatcher = {
        stamp = stamp,
        sending = newoperation("sending"),
        receiving = newoperation("receiving"),
    }
    function dispatcher.tcp()
        return cowrap(dispatcher, socket.tcp())
    end
    -- the table built above is local to this call; it is handed to the
    -- caller with the dispatcher methods attached through the metatable
    return base.setmetatable(dispatcher, cometat)
end

-- check-links.lua

-----------------------------------------------------------------------------
-- Little program that checks links in HTML files, using coroutines and
-- non-blocking I/O via the dispatcher module.
-- LuaSocket sample files
-- Author: Diego Nehab
-- RCS ID: $$
-----------------------------------------------------------------------------
local url = require("socket.url")
local dispatch = require("dispatch")
local http = require("socket.http")
-- override the dispatch module's TIMEOUT setting for this script
dispatch.TIMEOUT = 10

-- make sure the user knows how to invoke us
-- (every Lua file is loaded as a chunk, i.e. a function; fall back to an
-- empty argument list when no global 'arg' table is available)
arg = arg or {}
if #arg < 1 then
    print("Usage:\n luasocket check-links.lua [-n] {<url>}")
    -- nothing to check: leave with a failure status
    os.exit(1)
end

-- '-n' means we are running in non-blocking mode
local nonblocking = (arg[1] == "-n")
if nonblocking then
    -- drop the flag so that it is not treated as an address later on
    table.remove(arg, 1)
end
-- if non-blocking I/O was requested, use real dispatcher interface
-- (a coroutine dispatcher); if using blocking I/O, use fake dispatcher
-- interface (the sequential one)
handler = dispatch.newhandler(nonblocking and "coroutine" or "sequential")

-- number of link checks that have been started and have not finished yet
local nthreads = 0

-- get the status of a URL using the dispatcher
-- link: the URL to check; only links whose scheme is "http" are checked
--       (links without a scheme are parsed with "file" as the default)
-- Starts a HEAD request through the handler and, once it completes, writes
-- one line for the link: a tab and the link, followed by ': ' and the
-- status reported by the request when it was not a plain 200.
function getstatus(link)
    local parsed = url.parse(link, {scheme = "file"})
    if parsed.scheme == "http" then
        nthreads = nthreads + 1
        handler:start(function()
            -- function call with a single table argument, so the
            -- parentheses are omitted
            local r, c = http.request{
                method = "HEAD",
                url = link,
                create = handler.tcp
            }
            if r and c == 200 then io.write('\t', link, '\n')
            else io.write('\t', link, ': ', tostring(c), '\n') end
            nthreads = nthreads - 1
        end)
    end
end

-- Reads a whole file.
-- path: URL-escaped path of the file
-- returns the file's contents, or nil followed by the message from io.open
function readfile(path)
    local file, err = io.open(url.unescape(path), "r")
    if not file then
        return nil, err
    end
    local body = file:read("*a")
    file:close()
    return body
end

-- Fetches the document at u.
-- u: an http URL, or a file URL/path (a missing scheme is treated as "file")
-- returns base, body, err where
--   base  is the URL relative links should be resolved against: u itself,
--         or the 'location' header of a 200 response when there is one
--   body  is the document's contents, or nil on failure
--   err   describes the failure when body is nil
-- Declared local so that it does not replace the standard global 'load';
-- it stays visible to the code that follows in this chunk.
local function load(u)
    local parsed = url.parse(u, { scheme = "file" })
    local body, headers, code, err
    local base = u
    if parsed.scheme == "http" then
        body, code, headers = http.request(u)
        if code == 200 then
            -- if there was a redirect, update base to reflect it
            base = headers.location or base
        end
        if not body then
            -- on failure the second result of the request is the reason
            err = code
        end
    elseif parsed.scheme == "file" then
        body, err = readfile(parsed.path)
    else err = string.format("unhandled scheme '%s'", parsed.scheme) end
    return base, body, err
end

-- Collects the targets of all href attributes found in an HTML document.
-- body: the document text
-- base: URL used to turn relative hrefs into absolute ones
-- returns a list of absolute URLs: double-quoted hrefs first, then
-- single-quoted ones, then unquoted ones
function getlinks(body, base)
    -- get rid of comments
    body = string.gsub(body, "%<%!%-%-.-%-%-%>", "")
    local links = {}
    -- records one href and erases the matched attribute from the text.
    -- returning "" matters: when a gsub callback returns nothing the match
    -- is left in place, and the quoted attributes would then be picked up
    -- a second time (quotes included) by the unquoted pattern below
    local function collect(href)
        table.insert(links, url.absolute(base, href))
        return ""
    end
    -- extract links
    body = string.gsub(body, '[Hh][Rr][Ee][Ff]%s*=%s*"([^"]*)"', collect)
    body = string.gsub(body, "[Hh][Rr][Ee][Ff]%s*=%s*'([^']*)'", collect)
    string.gsub(body, "[Hh][Rr][Ee][Ff]%s*=%s*(.-)>", collect)
    return links
end

-- Loads the document at address and requests a status check for every link
-- found in it; prints the error instead when the document cannot be loaded.
function checklinks(address)
    local base, body, err = load(address)
    if body then
        print("Checking ", base)
        for _, link in ipairs(getlinks(body, base)) do
            getstatus(link)
        end
    else
        print(err)
    end
end

-- check every address given on the command line, resolving each one
-- against "file:" first
for _, target in ipairs(arg) do
    local address = url.absolute("file:", target)
    checklinks(address)
end

-- keep stepping the dispatcher until every started check has finished
while nthreads > 0 do
    handler:step()
end